/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.dllib.keras.layers

import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity
import com.intel.analytics.bigdl.dllib.nn.internal.KerasLayer
import com.intel.analytics.bigdl.dllib.tensor.Tensor
import com.intel.analytics.bigdl.dllib.utils.{Shape, T, TestUtils}
import com.intel.analytics.bigdl.dllib.keras.autograd.Variable
import com.intel.analytics.bigdl.dllib.keras.Model
import com.intel.analytics.bigdl.dllib.keras.serializer.ModuleSerializationTest
import com.intel.analytics.bigdl.dllib.keras.ZooSpecHelper

class BertSpec extends ZooSpecHelper {
  "Bert " should "be able to work" in {
    // Build a compact BERT that only emits the final block's output.
    val bert = BERT[Float](
      vocab = 100,
      hiddenSize = 10,
      nBlock = 3,
      nHead = 2,
      intermediateSize = 64,
      hiddenPDrop = 0.1,
      attnPDrop = 0.1,
      maxPositionLen = 10,
      outputAllBlock = false,
      inputSeqLen = 10)

    // Four inputs: word ids, segment ids, position ids and the attention mask.
    val inputShape = Shape(List(Shape(1, 10), Shape(1, 10), Shape(1, 10), Shape(1, 1, 1, 10)))
    bert.build(inputShape)
    val weights = bert.parameters()._1
//    TestUtils.conditionFailTest(w.length == 43)

    val wordIds = Tensor[Float](Array[Float](7, 20, 39, 27, 10,
      39, 30, 21, 17, 15), Array(1, 10))
    val tokenTypeIds = Tensor[Float](Array[Float](0, 0, 0, 0, 0, 1, 1, 1, 1, 1), Array(1, 10))
    val posIds = Tensor[Float](Array[Float](0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(1, 10))
    val attentionMask = Tensor[Float](1, 1, 1, 10).fill(1.0f)
    val input = T(wordIds, tokenTypeIds, posIds, attentionMask)

    // Forward once, then backward with random gradients for both BERT outputs
    // (sequence output and pooled output) — this is a smoke test only.
    val out = bert.forward(input)
    val gradOut = T(Tensor[Float](1, 10, 10).rand(), Tensor[Float](1, 10).rand())
    val gradIn = bert.backward(input, gradOut)
  }

  "Bert with output all blocks " should "be able to work" in {
    // Same small BERT as above, but emitting every transformer block's output.
    val bert = BERT[Float](
      vocab = 100,
      hiddenSize = 10,
      nBlock = 3,
      nHead = 2,
      intermediateSize = 64,
      hiddenPDrop = 0.1,
      attnPDrop = 0.1,
      maxPositionLen = 10,
      outputAllBlock = true)

    val inputShape = Shape(List(Shape(1, 10), Shape(1, 10), Shape(1, 10), Shape(1, 1, 1, 10)))
    bert.build(inputShape)
    val weights = bert.parameters()._1
//    TestUtils.conditionFailTest(w.length == 43)

    val wordIds = Tensor[Float](Array[Float](7, 20, 39, 27, 10,
      39, 30, 21, 17, 15), Array(1, 10))
    val tokenTypeIds = Tensor[Float](Array[Float](0, 0, 0, 0, 0, 1, 1, 1, 1, 1), Array(1, 10))
    val posIds = Tensor[Float](Array[Float](0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(1, 10))
    val attentionMask = Tensor[Float](1, 1, 1, 10).fill(1.0f)
    val input = T(wordIds, tokenTypeIds, posIds, attentionMask)

    // Smoke test: forward, then backward reusing the forward output as the
    // gradient (shapes match, values are irrelevant here).
    val out = bert.forward(input)
    val gradIn = bert.backward(input, out)
  }

  "Bert with configed embedding" should "be able to work" in {
    val vocabSize = 100
    val sequenceLen = 10
    val dropP = 0.1
    val hidden = 10

    // Three symbolic inputs feeding a user-supplied embedding graph.
    val wordIn = Variable[Float](Shape(sequenceLen))
    val tokenIn = Variable[Float](Shape(sequenceLen))
    val posIn = Variable[Float](Shape(sequenceLen))

    val wordEmb = Embedding[Float](vocabSize, hidden, inputLength = sequenceLen)
      .from[Float](wordIn)
    val posEmb = Embedding[Float](sequenceLen, hidden, inputLength = sequenceLen)
      .from[Float](posIn)
    val tokenEmb = Embedding[Float](2, hidden, inputLength = sequenceLen)
      .from[Float](tokenIn)

    // Sum the three embeddings, then layer-norm and dropout, mirroring BERT's
    // built-in embedding stage.
    val summed = wordEmb.asInstanceOf[Variable[Float]] + posEmb + tokenEmb
    val normed = LayerNorm[Float](nOutput = hidden, eps = 1e-12).from(summed)
    val dropped = Dropout[Float](dropP).from(normed)

    val customEmbedding = Model(Array(wordIn, tokenIn, posIn), dropped)

    // Hand the custom embedding layer to BERT instead of its default one.
    val bert = BERT[Float](nBlock = 3, nHead = 2,
      intermediateSize = 64, hiddenPDrop = 0.1, attnPDrop = 0.1, initializerRange = 0.02,
      outputAllBlock = false, customEmbedding
        .asInstanceOf[KerasLayer[Activity, Tensor[Float], Float]])
    bert.build(Shape(List(Shape(1, 10), Shape(1, 10), Shape(1, 10), Shape(1, 1, 1, 10))))

    val wordIds = Tensor[Float](Array[Float](7, 20, 39, 27, 10,
      39, 30, 21, 17, 15), Array(1, 10))
    val tokenTypeIds = Tensor[Float](Array[Float](0, 0, 0, 0, 0, 1, 1, 1, 1, 1), Array(1, 10))
    val posIds = Tensor[Float](Array[Float](0, 1, 2, 3, 4, 5, 6, 7, 8, 9), Array(1, 10))
    val attentionMask = Tensor[Float](1, 1, 1, 10).fill(1.0f)
    val input = T(wordIds, tokenTypeIds, posIds, attentionMask)

    // Smoke test: forward, then backward reusing the output as the gradient.
    val out = bert.forward(input)
    val gradIn = bert.backward(input, out)
  }

  "BERT" should "be able to generate correct result" in {
    val layer = BERT[Float](vocab = 20,
      hiddenSize = 10,
      nBlock = 1,
      nHead = 2,
      intermediateSize = 64,
      hiddenPDrop = 0,
      attnPDrop = 0,
      maxPositionLen = 6,
      outputAllBlock = false)
    val shape = Shape(List(Shape(2, 6), Shape(2, 6), Shape(2, 6), Shape(2, 1, 1, 6)))
    layer.build(shape)

    val wb = layer.parameters()._1
    val data = Array[Float](6f, 19f, 14f, 10f, 7f, 6f,
      18f, 10f, 10f, 3f, 7f, 2f)

    val wordEmbedding = Tensor[Float](Array[Float](-0.0207f, -0.0060f, -0.0170f, 0.0058f,
      -0.0179f, 0.0373f, 0.0010f, 0.0004f, -0.0026f, 0.0072f,
      -0.0225f, 0.0093f, -0.0610f, 0.0170f, 0.0116f, 0.0049f, -0.0276f, -0.0073f,
      -0.0040f, 0.0111f, 0.0053f, -0.0054f, 0.0242f, -0.0171f, 0.0179f, 0.0219f,
      0.0001f, 0.0084f, 0.0149f, 0.0004f, -0.0089f, 0.0248f, 0.0279f, -0.0002f,
      0.0009f, 0.0044f, 0.0435f, 0.0439f, -0.0129f, 0.0336f,
      -0.0180f, -0.0170f, -0.0191f, 0.0007f, 0.0084f, 0.0393f, -0.0193f, 0.0152f,
      -0.0064f, 0.0113f, 0.0051f, -0.0221f, 0.0165f, 0.0001f, -0.0345f, 0.0046f,
      0.0078f, -0.0105f, -0.0024f, 0.0083f,
      0.0202f, 0.0136f, -0.0146f, -0.0166f, -0.0314f, 0.0129f, 0.0035f, 0.0060f,
      -0.0120f, 0.0222f,
      -0.0054f, -0.0088f, -0.0108f, 0.0098f, 0.0031f, 0.0040f, 0.0137f, -0.0085f,
      -0.0151f, 0.0034f,
      -0.0229f, -0.0127f, 0.0236f, 0.0075f, 0.0033f, -0.0116f, -0.0122f, -0.0041f,
      0.0147f, 0.0283f,
      -0.0014f, -0.0050f, -0.0126f, -0.0021f, 0.0197f, -0.0042f, -0.0077f, -0.0154f,
      -0.0201f, -0.0193f,
      0.0042f, -0.0283f, 0.0178f, -0.0061f, -0.0087f, 0.0059f, 0.0023f, -0.0036f,
      -0.0019f, -0.0127f,
      0.0403f, 0.0129f, -0.0210f, 0.0217f, -0.0181f, -0.0208f, 0.0076f, -0.0036f,
      -0.0135f, -0.0150f,
      -0.0255f, -0.0196f, 0.0156f, -0.0048f, -0.0153f, -0.0173f, -0.0058f, -0.0217f,
      0.0012f, 0.0070f,
      0.0219f, -0.0198f, 0.0400f, -0.0260f, -0.0180f, -0.0484f, 0.0244f, -0.0031f,
      -0.0102f, 0.0024f,
      0.0286f, 0.0417f, 0.0201f, -0.0025f, -0.0059f, 0.0126f, 0.0159f, 0.0139f,
      -0.0037f, 0.0040f,
      0.0041f, -0.0145f, -0.0018f, -0.0237f, -0.0047f, 0.0134f, -0.0144f, 0.0196f,
      0.0035f, 0.0020f,
      0.0104f, 0.0021f, -0.0014f, 0.0037f, -0.0227f, -0.0122f, 0.0095f, 0.0045f,
      0.0092f, 0.0073f,
      0.0050f, -0.0114f, -0.0377f, 0.0325f, -0.0141f, 0.0112f, 0.0084f, 0.0259f,
      -0.0012f, -0.0015f,
      0.0363f, 0.0083f, -0.0108f, -0.0134f, -0.0076f, -0.0327f, -0.0065f, 0.0422f,
      -0.0112f, 0.0078f,
      -0.0455f, 0.0075f, 0.0041f, -0.0203f, -0.0120f, 0.0036f, -0.0229f, 0.0268f,
      -0.0199f, -0.0120f),
      Array(20, 10))
    wb(1).set(wordEmbedding)

    val positionEmbedding = Tensor[Float](Array[Float](0.0176f, 0.0146f, 0.0055f, 0.0344f,
      0.0269f, -0.0253f, 0.0090f, 0.0132f, -0.0130f, -0.0178f,
      -0.0216f, 0.0285f, -0.0224f, 0.0078f, -0.0098f, -0.0058f, -0.0041f, -0.0039f,
      0.0334f, 0.0135f,
      0.0234f, 0.0004f, -0.0183f, 0.0044f, 0.0251f, -0.0162f, 0.0412f, 0.0061f,
      -0.0292f, -0.0041f,
      0.0006f, 0.0058f, 0.0084f, -0.0163f, -0.0139f, 0.0180f, -0.0153f, -0.0039f,
      -0.0073f, -0.0311f,
      -0.0172f, -0.0080f, 0.0336f, -0.0022f, -0.0219f, 0.0047f, 0.0117f, -0.0457f,
      0.0117f, 0.0255f,
      0.0176f, 0.0104f, 0.0091f, -0.0187f, -0.0011f, 0.0191f, -0.0137f, 0.0205f,
      0.0346f, -0.0065f),
      Array(6, 10))
    wb(0).set(positionEmbedding)

    val tokenTypeEmbedding = Tensor[Float](Array[Float](-0.0218f, -0.0038f, 0.0163f,
      -0.0165f, 0.0279f, -0.0080f, -0.0095f, 0.0120f, -0.0028f, -0.0104f,
      -0.0086f, -0.0187f, -0.0065f, 0.0186f, -0.0057f, 0.0169f, 0.0004f, -0.0335f,
      -0.0389f, 0.0020f), Array(2, 10))
    wb(2).set(tokenTypeEmbedding)

    val queryW = Tensor[Float](Array[Float](-0.0135f, -0.0378f, -0.0368f, 0.0026f,
      -0.0159f, 0.0246f, 0.0016f, 0.0361f, -0.0068f, -0.0093f,
      -0.0080f, -0.0262f, 0.0006f, -0.0118f, -0.0235f, 0.0348f, -0.0059f, -0.0069f,
      -0.0099f, -0.0260f, 0.0262f, -0.0053f, 0.0039f, -0.0140f, 0.0228f, 0.0038f,
      -0.0002f, 0.0071f, -0.0085f, 0.0214f,
      0.0542f, -0.0039f, 0.0350f, -0.0022f, -0.0164f, 0.0160f, -0.0154f, 0.0308f,
      -0.0355f, -0.0213f,
      0.0210f, 0.0277f, -0.0301f, -0.0217f, 0.0430f, -0.0185f, -0.0172f, -0.0003f,
      0.0195f, -0.0015f,
      -0.0434f, 0.0243f, -0.0362f, 0.0039f, 0.0134f, -0.0232f, -0.0143f, -0.0205f,
      -0.0296f, 0.0009f,
      -0.0021f, 0.0071f, 0.0066f, -0.0106f, 0.0007f, 0.0493f, -0.0033f, -0.0061f,
      0.0284f, -0.0091f,
      -0.0320f, 0.0155f, -0.0127f, -0.0050f, 0.0140f, 0.0288f, -0.0214f, -0.0033f,
      0.0104f, -0.0147f,
      0.0067f, -0.0152f, 0.0011f, -0.0301f, -0.0118f, -0.0033f, -0.0099f, 0.0101f,
      0.0200f, 0.0007f,
      0.0046f, -0.0397f, -0.0036f, 0.0179f, 0.0010f, -0.0106f, 0.0115f, 0.0077f,
      0.0054f, 0.0092f), Array(10, 10))
    wb(5).set(queryW)
//    wb(5).narrow(1, 1, 10).copy(queryW)

    val queryB = Tensor[Float](10)
    wb(6).set(queryB)

    val keyW = Tensor[Float](Array[Float](-0.0040f, 0.0094f, 0.0217f, -0.0098f, 0.0350f,
      -0.0075f, 0.0190f, 0.0261f, 0.0226f, -0.0024f,
      0.0321f, 0.0017f, 0.0387f, -0.0089f, -0.0095f, -0.0061f, 0.0248f, -0.0131f,
      0.0370f, -0.0239f,
      -0.0085f, -0.0146f, -0.0080f, -0.0001f, -0.0064f, -0.0075f, 0.0452f, -0.0066f,
      0.0372f, -0.0023f,
      -0.0136f, 0.0084f, 0.0180f, 0.0127f, -0.0170f, -0.0019f, -0.0242f, 0.0634f,
      0.0460f, -0.0028f,
      -0.0196f, -0.0011f, -0.0209f, -0.0032f, -0.0105f, -0.0239f, -0.0218f, 0.0302f,
      -0.0002f, 0.0009f,
      -0.0071f, 0.0032f, 0.0301f, 0.0172f, 0.0113f, -0.0277f, 0.0229f, -0.0151f,
      -0.0293f, 0.0016f,
      -0.0002f, -0.0080f, -0.0280f, 0.0092f, -0.0077f, 0.0132f, 0.0152f, -0.0129f,
      0.0320f, 0.0072f,
      0.0012f, -0.0214f, 0.0291f, -0.0059f, -0.0094f, -0.0300f, -0.0004f, -0.0301f,
      -0.0042f, 0.0388f,
      0.0249f, 0.0158f, 0.0310f, -0.0122f, -0.0155f, -0.0079f, 0.0113f, 0.0061f,
      -0.0030f, -0.0095f,
      0.0126f, -0.0046f, 0.0042f, -0.0291f, 0.0046f, -0.0052f, 0.0210f, 0.0109f,
      0.0345f, 0.0108f),
      Array(10, 10))
    wb(7).set(keyW)
//    wb(5).narrow(1, 11, 10).copy(keyW)

    val keyB = Tensor[Float](10)
    wb(8).set(keyB)

    val valueW = Tensor[Float](Array[Float](0.0014f, 0.0066f, 0.0252f, 0.0175f,
      -0.0458f, 0.0379f, -0.0464f, -0.0128f, -0.0171f, -0.0013f,
      -0.0159f, 0.0156f, 0.0389f, 0.0042f, -0.0055f, 0.0259f, 0.0276f, 0.0018f,
      -0.0041f, 0.0081f,
      -0.0032f, 0.0056f, -0.0088f, 0.0002f, -0.0376f, 0.0263f, 0.0234f, -0.0073f,
      -0.0090f, 0.0257f,
      -0.0124f, -0.0159f, 0.0061f, 0.0145f, -0.0018f, -0.0226f, 0.0045f, 0.0106f,
      0.0309f, -0.0070f,
      0.0030f, -0.0226f, 0.0275f, -0.0034f, -0.0555f, 0.0061f, 0.0150f, -0.0128f,
      -0.0415f, 0.0132f,
      0.0353f, -0.0234f, -0.0026f, -0.0120f, -0.0281f, -0.0058f, -0.0219f, -0.0173f,
      0.0167f, -0.0188f,
      0.0286f, 0.0040f, -0.0468f, -0.0318f, 0.0170f, -0.0019f, -0.0026f, -0.0007f,
      0.0001f, 0.0030f,
      -0.0291f, 0.0254f, 0.0218f, -0.0096f, -0.0126f, 0.0238f, -0.0207f, 0.0077f,
      -0.0555f, -0.0040f,
      -0.0111f, 0.0067f, 0.0196f, 0.0225f, -0.0029f, -0.0195f, 0.0125f, -0.0069f,
      -0.0345f, -0.0198f,
      0.0391f, -0.0013f, 0.0171f, -0.0002f, 0.0223f, -0.0236f, 0.0137f, 0.0313f,
      0.0203f, -0.0230f), Array(10, 10))
    wb(9).set(valueW)
//    wb(5).narrow(1, 21, 10).copy(valueW)

    val valueB = Tensor[Float](10)
    wb(10).set(valueB)
//    wb(6).set(Tensor[Float](30))

    val outputDenseW = Tensor[Float](Array[Float](0.0080f, 0.0329f, -0.0080f,
      -0.0139f, 0.0116f, -0.0004f, -0.0028f, 0.0020f, 0.0338f, -0.0377f,
      -0.0157f, -0.0255f, -0.0213f, -0.0127f, -0.0126f, -0.0335f, 0.0091f, 0.0052f,
      0.0100f, -0.0095f,
      0.0333f, -0.0291f, -0.0046f, 0.0111f, 0.0137f, -0.0137f, 0.0179f, 0.0333f,
      -0.0203f, -0.0278f,
      0.0177f, 0.0183f, -0.0021f, 0.0426f, 0.0049f, 0.0218f, 0.0154f, 0.0003f,
      0.0136f, -0.0026f,
      -0.0115f, -0.0074f, -0.0034f, 0.0409f, -0.0400f, -0.0246f, 0.0175f, -0.0418f,
      -0.0114f, 0.0070f,
      -0.0236f, 0.0197f, 0.0158f, -0.0062f, 0.0135f, 0.0536f, 0.0095f, 0.0115f,
      0.0038f, 0.0086f,
      0.0106f, 0.0106f, -0.0317f, 0.0484f, 0.0078f, -0.0049f, -0.0054f, 0.0027f,
      0.0095f, 0.0006f,
      -0.0030f, -0.0073f, -0.0194f, -0.0140f, -0.0018f, -0.0196f, 0.0205f, -0.0207f,
      -0.0400f, 0.0302f,
      0.0036f, 0.0057f, 0.0072f, -0.0207f, -0.0128f, 0.0107f, 0.0034f, -0.0216f,
      0.0173f, 0.0196f,
      0.0145f, 0.0027f, -0.0126f, -0.0029f, -0.0085f, -0.0298f, 0.0172f, 0.0607f,
      -0.0230f, 0.0045f),
      Array(10, 10))
    wb(11).set(outputDenseW)
//    wb(7).set(outputDenseW)

    val outputDenseB = Tensor[Float](10)
    wb(12).set(outputDenseB)
//    wb(8).set(outputDenseB)

    val intermediateDenseW = Tensor[Float](Array[Float](0.0006f, 0.0003f, 0.0235f,
      -0.0193f, -0.0049f, -0.0136f, -0.0202f, -0.0078f, -0.0276f, 0.0214f,
      -0.0181f, 0.0154f, 0.0088f, -0.0101f, 0.0423f, 0.0244f, -0.0157f, 0.0218f,
      -0.0013f, 0.0251f,
      0.0032f, -0.0349f, -0.0259f, 0.0261f, 0.0142f, 0.0059f, -0.0139f, -0.0160f,
      -0.0016f, -0.0100f,
      -0.0445f, -0.0035f, -0.0133f, -0.0110f, 0.0012f, 0.0308f, 0.0209f, -0.0053f,
      0.0044f, 0.0010f,
      0.0225f, 0.0109f, -0.0044f, 0.0082f, -0.0226f, -0.0478f, 0.0144f, -0.0317f,
      -0.0193f, -0.0211f,
      -0.0122f, 0.0022f, 0.0025f, -0.0288f, -0.0092f, 0.0144f, -0.0019f, -0.0136f,
      0.0147f, 0.0019f,
      0.0217f, 0.0162f, -0.0195f, -0.0052f, 0.0180f, 0.0064f, 0.0301f, -0.0000f,
      -0.0168f, -0.0199f,
      0.0394f, -0.0125f, 0.0156f, -0.0295f, 0.0183f, -0.0163f, -0.0066f, -0.0321f,
      0.0031f, 0.0248f,
      -0.0268f, -0.0021f, 0.0031f, -0.0303f, 0.0198f, 0.0111f, -0.0136f, 0.0194f,
      0.0167f, -0.0415f,
      0.0185f, 0.0376f, 0.0006f, -0.0073f, 0.0091f, 0.0152f, -0.0193f, 0.0191f,
      -0.0282f, 0.0163f,
      0.0287f, 0.0012f, -0.0179f, -0.0017f, -0.0121f, -0.0138f, 0.0041f, -0.0144f,
      -0.0229f, 0.0178f,
      0.0050f, 0.0195f, -0.0201f, -0.0174f, 0.0207f, 0.0228f, -0.0122f, 0.0113f,
      -0.0031f, -0.0058f,
      -0.0241f, -0.0034f, -0.0204f, 0.0091f, 0.0011f, 0.0038f, -0.0001f, -0.0128f,
      -0.0047f, 0.0037f,
      0.0165f, 0.0136f, 0.0081f, 0.0341f, 0.0233f, -0.0048f, 0.0045f, -0.0475f,
      0.0080f, -0.0459f,
      0.0191f, -0.0078f, 0.0439f, 0.0168f, -0.0298f, 0.0117f, -0.0128f, -0.0381f,
      -0.0043f, 0.0033f,
      0.0017f, -0.0076f, -0.0277f, 0.0113f, -0.0441f, 0.0057f, 0.0500f, 0.0011f,
      -0.0237f, 0.0164f,
      0.0160f, 0.0069f, -0.0142f, 0.0081f, 0.0191f, 0.0061f, 0.0064f, -0.0366f,
      0.0370f, -0.0258f,
      0.0253f, -0.0194f, -0.0092f, 0.0168f, -0.0039f, -0.0028f, 0.0014f, -0.0014f,
      0.0254f, 0.0458f,
      -0.0016f, -0.0429f, 0.0319f, 0.0508f, -0.0099f, 0.0312f, -0.0240f, -0.0127f,
      0.0041f, 0.0242f,
      -0.0260f, -0.0357f, -0.0348f, 0.0347f, 0.0147f, 0.0199f, 0.0002f, -0.0057f,
      -0.0214f, -0.0079f,
      0.0145f, 0.0227f, 0.0027f, -0.0057f, 0.0169f, -0.0333f, 0.0180f, 0.0058f,
      0.0197f, -0.0330f,
      -0.0169f, -0.0356f, 0.0336f, -0.0062f, 0.0036f, 0.0061f, 0.0129f, 0.0260f,
      0.0129f, -0.0264f,
      -0.0114f, -0.0380f, 0.0050f, 0.0145f, -0.0002f, 0.0065f, -0.0080f, -0.0109f,
      -0.0190f, -0.0117f,
      0.0122f, 0.0372f, -0.0410f, 0.0244f, 0.0275f, 0.0099f, -0.0101f, 0.0262f,
      -0.0017f, 0.0043f,
      0.0022f, -0.0100f, 0.0282f, 0.0182f, 0.0167f, -0.0091f, -0.0029f, 0.0022f,
      0.0138f, 0.0087f,
      0.0033f, 0.0186f, 0.0086f, -0.0448f, 0.0070f, -0.0300f, 0.0058f, -0.0057f,
      -0.0150f, -0.0285f,
      -0.0065f, 0.0050f, -0.0281f, 0.0001f, -0.0132f, -0.0186f, 0.0208f, 0.0346f,
      -0.0179f, -0.0350f,
      0.0048f, -0.0361f, -0.0083f, -0.0342f, 0.0073f, -0.0116f, -0.0142f, -0.0055f,
      0.0174f, 0.0043f,
      0.0104f, 0.0068f, 0.0202f, -0.0009f, 0.0060f, -0.0084f, -0.0208f, 0.0128f,
      0.0005f, -0.0207f,
      0.0138f, -0.0136f, -0.0185f, -0.0062f, 0.0039f, 0.0126f, -0.0345f, -0.0104f,
      0.0012f, 0.0084f,
      0.0212f, 0.0265f, 0.0227f, 0.0085f, 0.0155f, 0.0028f, -0.0056f, 0.0003f,
      -0.0181f, -0.0121f,
      0.0010f, -0.0120f, -0.0606f, 0.0052f, 0.0514f, -0.0095f, 0.0072f, -0.0116f,
      -0.0145f, 0.0276f,
      -0.0031f, -0.0277f, 0.0091f, 0.0078f, -0.0133f, -0.0108f, 0.0586f, 0.0008f,
      -0.0139f, -0.0124f,
      0.0149f, -0.0093f, 0.0071f, 0.0115f, -0.0385f, 0.0269f, 0.0243f, -0.0017f,
      -0.0151f, -0.0122f,
      0.0013f, 0.0041f, -0.0097f, -0.0099f, -0.0121f, -0.0077f, -0.0300f, 0.0285f,
      0.0085f, 0.0345f,
      -0.0235f, -0.0111f, 0.0125f, -0.0004f, -0.0348f, -0.0153f, 0.0286f, 0.0078f,
      -0.0279f, -0.0400f,
      0.0048f, 0.0498f, 0.0063f, -0.0078f, 0.0179f, 0.0145f, 0.0087f, 0.0195f,
      -0.0001f, 0.0015f,
      -0.0052f, 0.0159f, 0.0116f, -0.0027f, -0.0271f, -0.0079f, 0.0274f, 0.0257f,
      -0.0267f, 0.0189f,
      -0.0228f, -0.0224f, -0.0132f, 0.0376f, -0.0181f, -0.0002f, -0.0061f, -0.0406f,
      0.0057f, 0.0077f,
      -0.0047f, -0.0245f, -0.0159f, 0.0094f, 0.0019f, 0.0343f, -0.0032f, -0.0034f,
      -0.0164f, 0.0005f,
      0.0106f, -0.0076f, 0.0283f, 0.0098f, 0.0183f, 0.0035f, 0.0038f, -0.0032f,
      -0.0394f, -0.0049f,
      -0.0136f, 0.0096f, -0.0378f, 0.0060f, 0.0302f, -0.0080f, 0.0072f, 0.0089f,
      0.0074f, 0.0035f,
      -0.0243f, 0.0389f, 0.0227f, -0.0195f, 0.0194f, 0.0032f, -0.0026f, 0.0189f,
      0.0150f, 0.0270f,
      0.0332f, -0.0136f, 0.0172f, 0.0214f, 0.0207f, -0.0218f, -0.0197f, 0.0037f,
      0.0313f, -0.0277f,
      -0.0091f, 0.0140f, 0.0109f, 0.0110f, 0.0048f, -0.0226f, -0.0058f, -0.0289f,
      0.0066f, -0.0045f,
      0.0127f, 0.0254f, 0.0105f, -0.0018f, 0.0017f, -0.0110f, -0.0091f, 0.0212f,
      -0.0350f, 0.0166f,
      -0.0040f, -0.0245f, -0.0033f, -0.0127f, -0.0165f, -0.0073f, 0.0115f, -0.0108f,
      0.0429f, -0.0169f,
      0.0201f, 0.0020f, -0.0185f, 0.0134f, 0.0084f, -0.0044f, 0.0108f, 0.0054f,
      0.0131f, -0.0045f,
      -0.0054f, 0.0227f, 0.0257f, 0.0080f, 0.0311f, 0.0408f, -0.0023f, -0.0201f,
      -0.0076f, -0.0136f,
      -0.0211f, 0.0045f, 0.0205f, 0.0155f, 0.0129f, 0.0148f, -0.0150f, -0.0227f,
      0.0086f, -0.0045f,
      0.0008f, 0.0071f, 0.0388f, -0.0251f, 0.0092f, 0.0300f, 0.0078f, 0.0081f,
      -0.0096f, -0.0115f,
      -0.0465f, -0.0057f, -0.0101f, 0.0116f, -0.0535f, 0.0037f, -0.0263f, -0.0155f,
      -0.0019f, -0.0234f,
      0.0102f, 0.0069f, -0.0234f, 0.0051f, 0.0090f, -0.0155f, -0.0516f, 0.0267f,
      0.0062f, 0.0016f,
      0.0027f, 0.0255f, 0.0085f, 0.0024f, 0.0111f, 0.0061f, -0.0102f, -0.0077f,
      -0.0037f, -0.0033f,
      0.0101f, 0.0240f, 0.0110f, -0.0102f, 0.0328f, -0.0274f, -0.0027f, 0.0122f,
      -0.0555f, 0.0087f,
      0.0025f, 0.0304f, 0.0137f, 0.0337f, -0.0111f, -0.0078f, 0.0207f, 0.0033f,
      -0.0118f, -0.0003f,
      -0.0076f, -0.0052f, -0.0105f, -0.0085f, 0.0015f, 0.0321f, 0.0254f, 0.0038f,
      0.0072f, -0.0222f,
      0.0031f, 0.0124f, 0.0297f, -0.0053f, 0.0184f, 0.0030f, -0.0395f, -0.0132f,
      -0.0135f, 0.0015f,
      -0.0079f, -0.0080f, -0.0373f, 0.0206f, -0.0060f, -0.0010f, 0.0334f, 0.0142f,
      0.0197f, -0.0132f,
      0.0237f, -0.0328f, 0.0156f, -0.0173f, -0.0084f, 0.0183f, 0.0054f, 0.0004f,
      0.0042f, 0.0196f,
      -0.0209f, -0.0162f, 0.0099f, 0.0252f, 0.0012f, -0.0277f, -0.0299f, 0.0181f,
      -0.0158f, -0.0084f,
      0.0402f, -0.0134f, -0.0415f, -0.0018f, -0.0120f, -0.0050f, -0.0235f, -0.0153f,
      0.0204f, 0.0117f,
      0.0154f, -0.0392f, 0.0325f, -0.0381f, -0.0376f, 0.0295f, 0.0327f, -0.0193f,
      0.0228f, 0.0083f,
      0.0405f, -0.0203f, 0.0070f, 0.0116f, -0.0079f, 0.0091f, -0.0236f, 0.0158f,
      0.0228f, 0.0111f), Array(64, 10))
    wb(15).set(intermediateDenseW)
//    wb(11).set(intermediateDenseW)

    val intermediateDenseB = Tensor[Float](64)
    wb(16).set(intermediateDenseB)
//    wb(12).set(intermediateDenseB)

    val denseW = Tensor[Float](Array[Float](0.0024f, 0.0089f, -0.0195f, -0.0117f,
      -0.0310f, 0.0060f, -0.0069f, -0.0041f,
      -0.0089f, 0.0254f, -0.0135f, 0.0130f, -0.0195f, 0.0368f, -0.0001f, 0.0082f,
      -0.0144f, 0.0129f, -0.0204f, -0.0220f, 0.0183f, 0.0222f, -0.0418f, 0.0431f,
      0.0071f, 0.0083f, -0.0185f, -0.0300f, 0.0141f, 0.0074f, -0.0143f, -0.0032f,
      -0.0066f, 0.0128f, -0.0047f, 0.0038f, 0.0004f, -0.0048f, -0.0370f, 0.0054f,
      -0.0004f, 0.0084f, -0.0172f, 0.0058f, 0.0131f, 0.0420f, -0.0185f, -0.0092f,
      0.0021f, 0.0030f, 0.0423f, 0.0053f, -0.0223f, 0.0228f, -0.0230f, 0.0145f,
      -0.0164f, 0.0266f, -0.0060f, -0.0219f, 0.0120f, -0.0098f, -0.0039f, -0.0329f,
      -0.0014f, -0.0449f, 0.0316f, -0.0103f, -0.0040f, 0.0168f, -0.0071f, 0.0279f,
      -0.0078f, 0.0190f, -0.0430f, -0.0012f, 0.0052f, 0.0083f, -0.0110f, 0.0051f,
      0.0162f, -0.0106f, 0.0090f, -0.0160f, 0.0126f, 0.0279f, -0.0223f, 0.0029f,
      0.0170f, 0.0239f, 0.0109f, -0.0242f, 0.0238f, 0.0257f, -0.0290f, -0.0085f,
      0.0185f, 0.0005f, 0.0141f, 0.0212f, -0.0322f, 0.0091f, 0.0228f, 0.0048f,
      -0.0025f, 0.0011f, 0.0218f, -0.0260f, 0.0270f, -0.0235f, 0.0313f, -0.0046f,
      0.0225f, -0.0148f, 0.0046f, -0.0091f, 0.0159f, 0.0281f, -0.0175f, -0.0070f,
      -0.0492f, 0.0250f, 0.0430f, -0.0064f, 0.0119f, 0.0074f, -0.0302f, -0.0128f,
      0.0226f, 0.0085f, -0.0057f, 0.0074f, 0.0216f, 0.0073f, 0.0281f, 0.0219f,
      -0.0207f, 0.0196f, 0.0091f, 0.0054f, -0.0510f, 0.0014f, 0.0072f, 0.0376f,
      -0.0149f, -0.0296f, 0.0284f, -0.0399f, 0.0015f, -0.0115f, -0.0116f, -0.0292f,
      -0.0410f, -0.0226f, 0.0252f, -0.0114f, 0.0157f, 0.0082f, 0.0061f, -0.0046f,
      0.0048f, -0.0360f, 0.0069f, -0.0060f, 0.0047f, -0.0022f, -0.0265f, -0.0347f,
      -0.0053f, -0.0282f, -0.0149f, 0.0286f, 0.0383f, -0.0310f, -0.0286f, 0.0041f,
      -0.0032f, 0.0010f, 0.0128f, 0.0216f, -0.0242f, 0.0120f, 0.0015f, -0.0177f,
      0.0074f, 0.0295f, -0.0350f, -0.0006f, -0.0133f, -0.0297f, 0.0016f, -0.0109f,
      -0.0210f, -0.0130f, -0.0271f, -0.0266f, 0.0102f, -0.0292f, 0.0290f, 0.0293f,
      -0.0039f, -0.0015f, 0.0012f, -0.0123f, 0.0203f, 0.0119f, 0.0153f, 0.0400f,
      0.0078f, 0.0399f, -0.0127f, 0.0214f, 0.0232f, -0.0356f, 0.0162f, -0.0021f,
      -0.0005f, 0.0067f, -0.0055f, -0.0059f, -0.0103f, 0.0317f, 0.0279f, -0.0109f,
      0.0414f, -0.0036f, 0.0162f, -0.0096f, -0.0128f, 0.0236f, -0.0328f, 0.0178f,
      0.0045f, 0.0141f, 0.0132f, 0.0091f, -0.0279f, 0.0138f, 0.0049f, 0.0012f,
      -0.0112f, -0.0097f, 0.0095f, 0.0034f, 0.0169f, 0.0360f, 0.0286f, -0.0197f,
      0.0002f, 0.0061f, -0.0151f, 0.0147f, -0.0201f, 0.0042f, 0.0293f, -0.0217f,
      -0.0135f, -0.0050f, -0.0460f, 0.0308f, 0.0091f, 0.0032f, -0.0017f, 0.0053f,
      0.0099f, 0.0264f, 0.0091f, -0.0096f, 0.0596f, -0.0162f, -0.0155f, -0.0208f,
      0.0236f, -0.0138f, -0.0217f, 0.0156f, -0.0129f, -0.0283f, 0.0173f, -0.0075f,
      -0.0158f, -0.0106f, -0.0226f, -0.0018f, 0.0022f, 0.0440f, -0.0408f, 0.0079f,
      -0.0123f, 0.0128f, -0.0616f, -0.0272f, -0.0059f, 0.0226f, -0.0158f, -0.0008f,
      -0.0152f, 0.0109f, 0.0091f, 0.0182f, 0.0200f, -0.0054f, -0.0191f, -0.0135f,
      0.0062f, -0.0026f, 0.0007f, 0.0103f, 0.0183f, -0.0105f, -0.0038f, -0.0199f,
      -0.0106f, -0.0262f, -0.0283f, -0.0321f, -0.0033f, -0.0174f, -0.0081f, -0.0301f,
      0.0110f, -0.0254f, -0.0166f, -0.0069f, -0.0146f, -0.0113f, 0.0008f, -0.0096f,
      -0.0087f, 0.0118f, -0.0238f, 0.0062f, -0.0115f, 0.0065f, 0.0126f, 0.0376f,
      -0.0495f, 0.0261f, -0.0119f, 0.0071f, 0.0006f, -0.0024f, 0.0218f, -0.0158f,
      -0.0460f, 0.0048f, 0.0080f, -0.0245f, 0.0007f, 0.0236f, -0.0353f, -0.0200f,
      0.0311f, 0.0036f, -0.0223f, -0.0187f, -0.0084f, -0.0163f, 0.0003f, -0.0237f,
      0.0097f, -0.0038f, -0.0294f, -0.0068f, 0.0259f, -0.0401f, 0.0056f, -0.0162f,
      -0.0310f, 0.0010f, -0.0154f, 0.0128f, 0.0123f, -0.0041f, -0.0332f, 0.0176f,
      -0.0326f, -0.0103f, 0.0108f, -0.0245f, -0.0236f, -0.0234f, 0.0320f, -0.0077f,
      -0.0065f, 0.0223f, -0.0266f, -0.0250f, -0.0020f, 0.0250f, -0.0022f, -0.0281f,
      -0.0121f, 0.0187f, 0.0053f, -0.0302f, -0.0085f, -0.0090f, -0.0136f, -0.0192f,
      0.0088f, -0.0030f, -0.0073f, -0.0094f, -0.0147f, 0.0287f, 0.0004f, -0.0017f,
      0.0330f, -0.0036f, -0.0372f, 0.0194f, -0.0317f, 0.0010f, 0.0227f, -0.0391f,
      0.0113f, -0.0050f, 0.0289f, 0.0196f, -0.0264f, -0.0014f, 0.0054f, -0.0140f,
      0.0317f, 0.0028f, 0.0075f, -0.0158f, 0.0534f, -0.0028f, 0.0188f, -0.0002f,
      -0.0104f, 0.0370f, 0.0367f, 0.0415f, -0.0147f, -0.0154f, -0.0010f, 0.0320f,
      0.0042f, 0.0221f, 0.0262f, 0.0085f, -0.0098f, 0.0331f, 0.0082f, -0.0049f,
      0.0173f, -0.0290f, 0.0063f, -0.0201f, -0.0269f, 0.0254f, -0.0259f, -0.0148f,
      -0.0066f, 0.0066f, 0.0196f, -0.0298f, 0.0108f, 0.0267f, -0.0113f, 0.0133f,
      0.0086f, -0.0155f, -0.0161f, 0.0093f, -0.0034f, 0.0541f, 0.0132f, -0.0131f,
      0.0146f, 0.0019f, -0.0036f, 0.0140f, -0.0250f, 0.0182f, -0.0032f, -0.0088f,
      0.0148f, -0.0094f, -0.0319f, -0.0333f, 0.0068f, 0.0075f, -0.0266f, 0.0113f,
      0.0161f, 0.0149f, -0.0035f, 0.0222f, 0.0103f, 0.0179f, -0.0302f, -0.0170f,
      0.0416f, 0.0214f, -0.0286f, -0.0066f, 0.0341f, 0.0121f, 0.0222f, -0.0113f,
      -0.0273f, 0.0036f, 0.0113f, 0.0065f, 0.0003f, 0.0538f, 0.0255f, -0.0127f,
      0.0114f, 0.0102f, 0.0195f, 0.0395f, -0.0255f, -0.0157f, 0.0428f, 0.0069f,
      0.0190f, -0.0307f, -0.0034f, 0.0196f, -0.0012f, 0.0415f, -0.0080f, -0.0040f,
      -0.0033f, -0.0248f, -0.0211f, 0.0249f, -0.0145f, -0.0209f, 0.0113f, 0.0035f,
      0.0398f, -0.0236f, 0.0061f, 0.0445f, 0.0061f, -0.0126f, 0.0155f, -0.0102f,
      0.0060f, 0.0035f, 0.0017f, 0.0022f, 0.0213f, 0.0368f, -0.0114f, -0.0289f,
      -0.0051f, 0.0235f, -0.0086f, -0.0060f, 0.0030f, 0.0332f, -0.0524f, -0.0086f,
      -0.0020f, -0.0088f, -0.0396f, -0.0007f, -0.0157f, -0.0023f, 0.0380f, 0.0090f,
      -0.0117f, 0.0035f, -0.0041f, -0.0254f, 0.0175f, 0.0028f, -0.0281f, 0.0082f,
      0.0292f, -0.0106f, 0.0179f, -0.0228f, -0.0262f, -0.0019f, -0.0482f, -0.0002f,
      -0.0132f, 0.0161f, -0.0014f, -0.0303f, 0.0040f, -0.0049f, 0.0047f, 0.0152f,
      -0.0101f, 0.0150f, 0.0214f, -0.0527f, 0.0139f, -0.0065f, 0.0119f, -0.0125f,
      0.0329f, -0.0342f, 0.0019f, -0.0057f, 0.0203f, -0.0132f, -0.0080f, -0.0359f,
      -0.0076f, 0.0046f, -0.0146f, -0.0081f, 0.0241f, -0.0139f, -0.0303f, 0.0405f,
      -0.0162f, 0.0019f, 0.0055f, 0.0109f, -0.0495f, 0.0042f, 0.0051f, -0.0146f,
      -0.0276f, -0.0042f, 0.0178f, 0.0037f, 0.0101f, -0.0201f, -0.0268f, -0.0090f,
      -0.0032f, -0.0322f, 0.0036f, 0.0373f, -0.0240f, -0.0014f, 0.0029f, -0.0030f), Array(10, 64))
    wb(17).set(denseW)
//    wb(13).set(denseW)

    val denseB = Tensor[Float](10)
    wb(18).set(denseB)
//    wb(14).set(denseB)

    val input = Tensor[Float](data, Array(2, 6))
    val tokenTypeInput = Tensor[Float](2, 6)
    val positionInput = Tensor[Float](Array[Float](0, 1, 2, 3, 4, 5,
      0, 1, 2, 3, 4, 5), Array(2, 6))
    val attentionMask = Tensor[Float](2, 1, 1, 6).fill(1.0f)
    val finalInput = T(input, tokenTypeInput, positionInput,
      attentionMask)
    val output = layer.forward(finalInput).toTable

    val expect = Tensor[Float](Array[Float](0.5962f, 1.0420f, 0.1113f, -0.2090f, 0.9805f,
      -1.4064f, -0.1274f,
      1.4107f, -1.7832f, -0.6146f,
      -2.3127f, 1.2060f, 0.2066f, -0.5891f, 0.4503f, -0.0399f, -0.7912f,
      1.2796f, 0.5826f, 0.0078f,
      0.5818f, 0.8573f, 0.1493f, -1.0122f, 1.1666f, -0.9232f, 1.1834f,
      0.6349f, -1.7570f, -0.8808f,
      -0.2431f, -0.6013f, 1.9841f, -1.0812f, 0.5765f, 0.9691f, -0.4521f,
      0.5557f, -0.0692f, -1.6384f,
      -1.6054f, -0.6603f, 1.7060f, -0.2060f, 0.5141f, 0.1785f, 0.7932f,
      -1.5166f, -0.0973f, 0.8939f,
      0.4227f, 0.5839f, 0.2054f, -2.3533f, -0.4316f, 0.7340f, -1.0430f,
      1.3305f, 0.5629f, -0.0113f,
      0.7129f, 0.3546f, 0.1327f, -0.0403f, 1.1138f, -1.9698f, -0.3605f,
      1.6685f, -0.8874f, -0.7245f,
      -2.0706f, -0.0171f, 0.8533f, -0.6780f, 0.7322f, -0.2801f, -0.4571f,
      0.4468f, 1.8353f, -0.3647f,
      0.2750f, -1.1303f, 0.6455f, -0.6216f, 1.7108f, -0.6376f, 1.3264f,
      0.6033f, -1.2068f, -0.9646f,
      -1.2909f, 0.6121f, 1.4886f, -1.4065f, 0.2142f, 0.1850f, 0.3498f,
      1.4607f, -1.0593f, -0.5537f,
      -1.6073f, -0.6602f, 1.7037f, -0.2057f, 0.5179f, 0.1781f, 0.7954f,
      -1.5163f, -0.0971f, 0.8916f,
      -0.3329f, -0.3355f, 1.0987f, -1.9300f, 0.9568f, 0.5973f, -1.0527f,
      0.8440f, 1.0099f, -0.8555f), Array(2, 6, 10))
    TestUtils.conditionFailTest(output[Tensor[Float]](1).almostEqual(expect, 5e-3) == true)

    val gradOutput = Tensor[Float](Array[Float](21f, 52f, 1f, 87f, 29f, 37f,
      1f, 63f, 59f, 20f,
      32f, 75f, 57f, 21f, 88f, 48f, 90f, 58f, 41f, 91f,
      59f, 79f, 14f, 61f, 61f, 46f, 61f, 50f, 54f, 63f,
      2f, 50f, 6f, 20f, 72f, 38f, 17f, 3f, 88f, 59f,
      13f, 8f, 89f, 52f, 1f, 83f, 91f, 59f, 70f, 43f,
      7f, 46f, 34f, 77f, 80f, 35f, 49f, 3f, 1f, 5f,
      53f, 3f, 53f, 92f, 62f, 17f, 89f, 43f, 33f, 73f,
      61f, 99f, 13f, 94f, 47f, 14f, 71f, 77f, 86f, 61f,
      39f, 84f, 79f, 81f, 52f, 23f, 25f, 88f, 59f, 40f,
      28f, 14f, 44f, 64f, 88f, 70f, 8f, 87f, 0f, 7f,
      87f, 62f, 10f, 80f, 7f, 34f, 34f, 32f, 4f, 40f,
      27f, 6f, 72f, 71f, 11f, 33f, 32f, 47f, 22f, 61f
    ), Array(2, 6, 10))
    val grad2 = Tensor[Float](2, 10)

    layer.backward(finalInput, T(gradOutput, grad2))
    val grads = layer.parameters()._2

    val expectGrad = new Array[Tensor[Float]](grads.size - 2)
//    expectGrad(20 - 4) =
    expectGrad(20) =
      Tensor[Float](Array[Float](429f, 578f, 472f, 800f, 598f, 478f,
      568f, 610f, 517f, 563f), Array(1, 10))
//    expectGrad(19 - 4) =
    expectGrad(19) =
      Tensor[Float](Array[Float](-308.3242f, 73.9104f, 415.4557f,
      -667.1381f, 372.1354f, -32.4430f,
      -20.9703f, 388.4782f, -109.3576f, -259.4818f), Array(1, 10))
//    expectGrad(18 - 4) =
    expectGrad(18) =
      Tensor[Float](Array[Float](-102.6976f, -1.5482f, -77.3635f,
      190.8914f, 33.2574f, -53.2304f,
      -18.3125f, 52.5667f, -15.7316f, -7.8316f), Array(10))
//    expectGrad(17 - 4) =
    expectGrad(17) =
      Tensor[Float](Array[Float](-1.0580f, -3.3393f, 2.1110f, 0.1864f,
      2.7355f, -0.9865f, -0.4861f, 0.0079f,
      -7.1003f, -1.7762f, 2.8341f, -2.3958f, 2.1112f, 1.3432f, 1.8933f, 4.7895f,
      1.5008f, 3.9667f, 2.8231f, 2.4426f, -2.2035f, -5.1837f, 0.1428f, 0.7536f,
      -0.3784f, -4.3832f, -1.4782f, -1.7519f, -3.4882f, 0.3851f, -1.9032f, 3.0388f,
      1.0172f, 1.4051f, -0.3384f, -1.3813f, -2.3535f, 0.3197f, 5.5005f, 0.8471f,
      -1.7222f, 0.8297f, -2.9059f, -1.8899f, 1.4856f, -1.3102f, 0.0904f, 1.2437f,
      -1.9589f, 0.0078f, -5.3002f, 1.6773f, -1.4316f, -0.6492f, -2.8976f, 2.0590f,
      -1.2389f, -2.6494f, 2.4170f, -0.5960f, -1.1064f, 3.1222f, -0.7554f, -0.0974f,
      -1.5898f, -0.4975f, 0.3276f, 1.0791f, -0.1578f, 0.6827f, 1.2964f, -0.3858f,
      0.4937f, -0.5865f, 0.0835f, 0.7411f, 0.4751f, 0.6231f, -1.0914f, 0.7795f,
      1.5765f, 0.0451f, -1.2355f, 0.5161f, 0.3536f, -0.4878f, -0.5271f, 0.4877f,
      -0.9179f, -0.0761f, 0.2592f, 0.2591f, -1.0049f, 0.1577f, -0.7627f, 1.0529f,
      0.3782f, 0.3197f, -0.4623f, -0.4209f, 0.2576f, -0.8491f, 0.1410f, 0.3572f,
      -1.1737f, 0.9715f, -0.3530f, -0.7955f, -0.2649f, -1.4233f, 1.2601f, 0.6087f,
      -0.0484f, -0.2390f, -0.5762f, 0.3078f, -0.8644f, -0.1581f, -1.6243f, -0.5410f,
      1.3657f, -1.4342f, 1.5082f, -0.1304f, -1.8858f, 0.7989f, 0.3942f, -0.4180f,
      -0.7696f, -1.7046f, 3.0083f, 0.2815f, -1.8313f, 1.3514f, -2.1650f, 2.8038f,
      -2.3636f, -2.7893f, 0.6318f, -0.9494f, 0.3873f, 0.1486f, 3.6092f, 1.2046f,
      2.1854f, 4.0799f, 5.2118f, 1.8244f, -3.9598f, -0.3216f, 1.8582f, -2.5756f,
      0.3189f, -4.7066f, -4.6010f, 2.7134f, -1.8443f, 3.0252f, -2.6013f, 0.3272f,
      -0.1032f, 1.9697f, 0.1188f, -3.1437f, -3.9024f, -3.1310f, 3.4598f, 2.6818f,
      -1.4151f, -1.8735f, -3.6017f, 0.6903f, -0.8672f, -4.0220f, 2.4134f, 0.3420f,
      -0.2728f, 0.5919f, -0.9514f, 0.4869f, -1.1010f, -1.2868f, -7.0440f, -2.5691f,
      0.7652f, -0.6701f, 0.1827f, 4.1032f, -2.4273f, 4.2412f, 5.7438f, 3.6928f,
      3.9396f, 3.4900f, -3.4421f, -1.6823f, 2.0577f, -1.9800f, 2.9973f, 1.9318f,
      4.1292f, 2.3652f, -0.9422f, 0.0913f, -2.2378f, -0.5891f, -3.8551f, -3.8239f,
      -4.1143f, -5.2280f, -5.7310f, -2.7041f, 5.8461f, 4.9251f, 0.0422f, -0.9259f,
      2.2166f, 8.9307f, 4.4906f, 0.4724f, 3.5511f, -4.0872f, 4.0853f, 0.7849f,
      3.6977f, -3.8401f, -1.7302f, 4.5874f, 3.0021f, 3.5961f, -6.5022f, -3.4116f,
      6.3097f, 0.7883f, 3.9939f, 1.6307f, 0.6923f, 5.7231f, -3.0451f, -1.2674f,
      0.8538f, -1.0798f, 4.5407f, -6.3163f, -0.1319f, 0.6586f, 13.3009f, 1.0152f,
      -1.5308f, 2.8236f, -3.1783f, -1.5167f, 4.0155f, -7.3537f, -3.0804f, -4.1795f,
      0.1175f, 2.8403f, -4.0623f, -1.9653f, -2.6326f, 0.0631f, 2.2231f, -4.0464f,
      7.4562f, 6.3791f, -1.5825f, 5.3242f, -3.1754f, -4.8758f, -5.2144f, -2.5998f,
      -3.9367f, -5.1190f, -7.8036f, -4.2097f, 2.6958f, 2.5851f, -3.1077f, 3.8540f,
      -3.1926f, 5.2254f, 6.4354f, -0.6225f, 4.5647f, 0.0703f, 2.7995f, -4.7366f,
      -3.5315f, 0.3943f, 3.7462f, 2.5919f, 5.7584f, 2.9031f, -9.6024f, -1.1955f,
      -0.3986f, -0.3304f, 3.0981f, 0.2682f, -4.5044f, 4.6092f, -2.3424f, -0.2072f,
      -0.7548f, -4.0488f, 4.9132f, -0.4177f, 5.1700f, 0.9771f, 3.8022f, -0.8644f,
      1.6789f, 1.0282f, -0.9670f, -1.0238f, 0.8577f, -0.9700f, -3.2047f, 1.3942f,
      1.7744f, -0.9562f, -0.9035f, 1.2531f, -0.7631f, 1.9583f, -1.3938f, 0.8644f,
      -3.4384f, 1.2605f, 1.2901f, -0.3681f, 0.2464f, -1.2919f, 5.0387f, 3.0616f,
      -0.9015f, 1.1726f, 3.6620f, -0.6271f, -4.8742f, -2.1389f, 0.4338f, -2.4028f,
      -1.4068f, -2.2885f, -3.3989f, -1.2002f, -1.7157f, 1.1290f, -0.2544f, -3.0880f,
      -0.4894f, 3.8044f, 0.0978f, -0.1729f, -0.3444f, 1.5584f, 2.1371f, 1.8803f,
      0.0344f, -3.5890f, -0.7095f, -4.1357f, -0.6428f, -0.0654f, -0.7688f, -1.6236f,
      1.7461f, 0.7704f, 1.9783f, 2.1733f, -2.9364f, 0.3247f, -3.0133f, 0.3958f,
      0.5420f, 0.7125f, -2.0652f, 2.6675f, -3.1209f, 0.9533f, 5.4835f, 0.7168f,
      -0.9554f, -0.0771f, 1.2071f, 1.8164f, -0.1728f, 0.5448f, -3.6251f, -0.6047f,
      -3.1048f, -2.7448f, -0.2879f, -1.9916f, 2.5872f, 0.3954f, 0.7772f, 0.2551f,
      1.6153f, 3.3852f, 3.2086f, 1.3703f, -1.4906f, -3.4722f, -0.5200f, 0.7647f,
      1.2109f, -4.4440f, -3.0299f, -0.3078f, -2.3950f, 0.4259f, -3.0080f, 1.9661f,
      -2.0276f, -2.2776f, 1.7251f, -3.3858f, -1.8279f, -1.6760f, 5.6411f, -0.6422f,
      -3.8286f, 2.0208f, 1.5643f, -0.5905f, 2.8658f, -2.5222f, 1.4525f, 0.2251f,
      -1.4085f, 2.4551f, -5.2825f, 3.3492f, 1.0270f, -0.1826f, -4.5002f, 0.5604f,
      -2.0224f, -0.9700f, 1.5626f, -1.8220f, 0.8803f, 1.6179f, -1.7712f, -0.4249f,
      0.9409f, 1.6270f, -1.2174f, 0.4130f, -0.0277f, -0.5105f, 1.5623f, 0.0643f,
      1.9104f, 0.6550f, -0.5127f, 0.3411f, -0.5124f, -1.0632f, -2.5121f, -0.4328f,
      -1.7070f, -2.0395f, -2.8417f, -0.2785f, 1.6442f, 2.1625f, 0.0223f, -0.4180f,
      0.2658f, 3.0633f, 2.1107f, 0.2655f, 0.5619f, -1.6916f, 0.8388f, 0.8603f,
      2.1854f, -1.0753f, -0.7817f, 2.1608f, 1.2299f, 1.7812f, -2.6473f, -0.7186f,
      2.0909f, 0.7126f, 1.6350f, -0.6132f, -0.4510f, 1.7137f, -0.9554f, -0.4657f,
      0.0787f, -0.8713f, 1.8393f, -2.3060f, -0.8735f, -0.1183f, 4.4280f, 0.0949f,
      0.3115f, 0.0815f, -0.3829f, -0.4065f, 0.8010f, -2.9417f, -0.9546f, -2.0959f,
      -1.1986f, -1.9704f, 2.0208f, -1.3183f, 1.1335f, -0.7893f, 0.7183f, 0.4771f,
      0.7428f, -1.7938f, -0.3657f, -0.8388f, -0.5159f, 3.9766f, 2.0015f, -1.2136f,
      2.4146f, -0.3970f, 2.1009f, 1.0788f, 1.0780f, 2.1435f, 1.8238f, -1.0402f,
      1.0956f, -0.2508f, -0.4935f, 0.2195f, 1.1429f, 0.2254f, 1.1184f, -0.6708f,
      0.9235f, 1.0787f, -2.7879f, 0.6198f, -1.6227f, -2.4794f, 0.7840f, 0.7097f,
      1.5439f, -0.9371f, -3.3197f, 3.4226f, 0.5496f, -1.6191f, 1.2957f, 0.6614f,
      1.6252f, 1.1054f, 0.5792f, -0.2579f, -0.9100f, 0.0475f, -1.0327f, -0.1079f,
      0.7437f, 0.6119f, 0.2052f, 0.4673f, 0.4120f, 0.1858f, 0.9204f, 1.1383f,
      -1.2011f, 0.5879f, 0.9507f, -0.0636f, -0.3413f, -0.3340f, -1.1274f, -1.1123f,
      1.2748f, -0.9691f, -1.1485f, 0.0460f, 0.6342f, 1.3331f, -0.6470f, -2.0203f,
      1.3669f, 0.1339f, 0.6053f, 0.5873f, 0.9103f, -0.2120f, -0.1684f, 1.5026f,
      0.7879f, -1.0703f, -0.2945f, -0.0474f, 0.6274f, 0.3603f, -0.3126f, 0.4654f,
      -2.0503f, -1.7789f, 0.4124f, -1.4554f, -0.1972f, -2.0230f, 1.0883f, -0.5082f,
      -1.4407f, 1.4072f, 0.5985f, 2.0130f, 1.1370f, -1.0839f, 0.5995f, 0.4831f,
      0.1396f, 1.3082f, -1.7404f, 1.3033f, 2.0517f, 0.3871f, -1.4191f, -0.0428f,
      -0.6149f, 0.4659f, 0.7178f, -1.7426f, 1.5739f, 0.3461f, -2.7757f, 0.2737f), Array(10, 64))
//    expectGrad(16 - 4) =
    expectGrad(16) =
      Tensor[Float](Array[Float](-3.3787f, -2.6050f, -1.0664f, -2.2213f,
      1.6643f, -2.3580f, 1.3027f, 2.4582f,
      1.4468f, -1.7931f, 1.6253f, -3.0072f, 6.0387f, -0.3518f, 0.6676f, 1.1960f,
      3.7489f, 2.8823f, -1.1145f, 4.6594f, 1.1728f, -3.7206f, 4.0018f, -1.1291f,
      2.0657f, 1.3442f, -1.2284f, 2.0001f, -2.8781f, 2.8458f, 3.0325f, 0.1283f,
      3.6492f, 0.2572f, -0.3039f, -2.2268f, -1.2262f, 3.4637f, -0.7994f, 3.4880f,
      0.3507f, 2.5811f, 3.8588f, 0.7987f, -5.3275f, 1.7393f, 1.4253f, 0.3357f,
      1.0857f, -0.7352f, -2.9298f, -1.1984f, 4.9273f, 2.2179f, 5.9523f, -3.1779f,
      0.4895f, -2.3895f, -0.2659f, 2.5864f, -1.4725f, 2.9910f, 2.7749f, -0.5007f), Array(64))
//    expectGrad(15 - 4) =
    expectGrad(15) =
      Tensor[Float](Array[Float]( 0.5872f, -2.5162f, -1.3089f, 2.6942f,
      -1.8444f, 2.8932f, -0.6094f, -4.4749f, 2.9703f, 1.6089f,
      -0.1863f, -1.3689f, -1.6873f, 4.4071f, -1.3313f, -0.6168f, -1.6696f, -3.1148f,
      3.2937f, 2.2741f,
      -0.2983f, -2.7319f, -0.3790f, 2.0693f, 1.2936f, 1.1336f, 1.3378f, -4.2948f,
      2.4854f, -0.6158f,
      5.3777f, -0.1691f, -0.8867f, -1.6383f, -2.5449f, 4.1551f, -2.7639f, 3.8600f,
      -0.2083f, -5.1816f,
      0.7005f, -0.6102f, 1.0706f, -2.8594f, 0.7100f, 0.2508f, -0.9104f, 2.8077f,
      -0.3110f, -0.8486f,
      -1.1853f, 0.5420f, -3.8510f, 1.8732f, -2.0878f, 0.8803f, -0.5686f, -1.4087f,
      3.2729f, 2.5331f,
      2.5187f, -3.1172f, 3.1268f, -1.1514f, 1.4862f, 0.1779f, 0.8737f, -0.5114f,
      -1.8580f, -1.5452f,
      1.9134f, -1.2150f, 1.7987f, -3.1422f, 2.9289f, -0.5423f, 0.8452f, 2.7739f,
      -2.0747f, -3.2858f,
      0.7538f, 0.8152f, 1.5800f, -1.5531f, 0.9392f, -0.1448f, -0.2932f, 3.1015f,
      -1.5604f, -3.6383f,
      -1.0012f, 1.5065f, -2.7656f, -0.2901f, -3.1118f, 1.8041f, -1.2266f, 0.3286f,
      2.1675f, 2.5886f,
      -0.3267f, -0.4473f, 1.2662f, -0.9423f, 2.1091f, -2.0737f, -0.9209f, 3.2504f,
      -0.6298f, -1.2849f,
      3.3928f, -1.9985f, -0.3978f, 1.1760f, -2.8006f, 3.3489f, 0.7413f, -3.3493f,
      -0.3098f, 0.1969f,
      -0.1599f, 6.2460f, 2.0093f, -7.0386f, 3.5855f, -3.3381f, -1.7027f, 13.6008f,
      -4.9608f, -8.2415f,
      0.3760f, -1.8128f, 0.7492f, 1.8306f, 1.2593f, -1.0826f, 2.7678f, -3.1596f,
      -1.2895f, 0.3616f,
      -0.3199f, -1.0371f, 1.8375f, -0.1672f, 0.6072f, -0.1967f, 0.8958f, -1.7211f,
      -1.1387f, 1.2402f,
      -0.0228f, -2.7427f, 3.4076f, -0.8268f, 1.3378f, 0.0885f, 2.4008f, -2.9111f,
      -2.5866f, 1.8553f,
      -1.1236f, 1.8388f, -0.0614f, -2.7218f, 3.9897f, -3.7199f, -0.3056f, 6.6764f,
      -0.3387f, -4.2340f,
      -2.2042f, 1.6541f, 2.7283f, -0.3833f, 1.8652f, -3.3456f, 1.6011f, -0.1532f,
      -3.2342f, 1.4717f,
      -0.0046f, -1.9237f, -0.9333f, 0.9282f, -0.1384f, 1.2859f, -0.6657f, -2.6732f,
      2.6262f, 1.4985f,
      0.8256f, 1.7435f, 4.3902f, -4.6308f, 2.9955f, -2.1035f, 1.4411f, 5.9252f,
      -5.9105f, -4.6765f,
      -0.2924f, -0.1363f, 0.4996f, 0.0502f, 1.9162f, -2.3093f, 1.8330f, -0.5018f,
      -1.6993f, 0.6401f,
      -3.3194f, -0.0227f, -5.3374f, 5.2252f, -2.2696f, -0.5928f, 2.7140f, -6.8166f,
      3.6890f, 6.7303f,
      -0.2722f, 1.3635f, 4.4126f, -4.4627f, 2.2881f, -0.2952f, -1.4799f, 5.4720f,
      -2.9933f, -4.0329f,
      -0.3378f, 1.0920f, -1.8697f, 2.9920f, -0.2655f, -2.2326f, 1.7312f, -1.7645f,
      -0.4407f, 1.0956f,
      -4.2179f, 0.8622f, -1.1042f, 4.3262f, 4.9683f, -6.3190f, 0.3592f, -0.4476f,
      3.4713f, -1.8986f,
      -0.8469f, 1.6589f, -0.1070f, -1.3429f, 0.1628f, -1.2248f, 2.9341f, 0.0421f,
      -2.7314f, 1.4552f,
      2.3185f, -3.4832f, 1.0468f, -0.5750f, -0.2343f, 3.2997f, 0.7922f, -2.8669f,
      0.1272f, -0.4251f,
      -0.9879f, -0.9914f, 1.9991f, 0.4696f, 3.2606f, -2.0653f, -0.8552f, 1.5247f,
      0.9491f, -3.3035f,
      2.4984f, -1.2291f, -1.6790f, 0.8785f, -2.1509f, 3.4557f, -0.6182f, -1.9253f,
      1.5813f, -0.8114f,
      1.0926f, 2.4823f, 1.2937f, -5.6106f, 0.1906f, 0.2032f, 0.1359f, 6.0852f,
      -3.8113f, -2.0615f,
      -2.0084f, -3.0505f, 2.3530f, 2.8234f, 5.8682f, -6.4606f, 1.7312f, -2.5815f,
      0.1848f, 1.1403f,
      2.6273f, -0.1101f, -0.0103f, -2.3786f, 0.0151f, 0.9065f, 0.5685f, 3.0627f,
      -2.3250f, -2.3562f,
      -1.7334f, -0.6855f, 3.9414f, -1.9662f, 2.8874f, -1.9005f, 2.7269f, -0.9034f,
      -3.1799f, 0.8132f,
      -0.0157f, 2.7101f, -0.4228f, -0.1642f, -0.5389f, -0.5992f, -0.3527f, 2.1748f,
      -1.2474f, -1.5438f,
      -2.4542f, -3.6721f, -0.4022f, 5.6782f, 2.4936f, -3.6093f, 2.4183f, -8.4913f,
      3.5164f, 4.5226f,
      -0.8921f, -1.2243f, -2.9788f, 4.7178f, -0.7807f, -0.6303f, 0.3070f, -6.0139f,
      4.3912f, 3.1042f,
      1.4402f, -1.7848f, 0.7219f, 1.3682f, 0.5861f, 0.5263f, -0.7392f, -0.4232f,
      0.1029f, -1.7984f,
      1.4850f, 0.4200f, 2.1811f, -2.9013f, 3.7298f, -2.4677f, 0.5097f, 5.2131f,
      -3.1225f, -5.0472f,
      -0.6444f, 0.8013f, -1.1430f, 0.0884f, -1.8921f, 2.1773f, -0.6861f, -1.1477f,
      2.0537f, 0.3926f,
      -2.3740f, 2.7361f, 0.4183f, -1.0867f, 4.1195f, -4.9933f, 0.7930f, 4.3530f,
      -2.0071f, -1.9588f,
      -2.4424f, -0.1163f, 0.1914f, 1.4941f, 0.0591f, -1.4626f, 1.3336f, -3.0171f,
      0.4350f, 3.5251f,
      -0.6471f, 1.7754f, 1.3374f, -0.8439f, 2.6679f, -3.1142f, 0.8498f, 3.1290f,
      -2.3885f, -2.7659f,
      -2.1548f, 1.9963f, 0.3507f, -2.1811f, 3.7489f, -3.4828f, -0.4870f, 4.8545f,
      0.3124f, -2.9571f,
      1.6489f, -0.9306f, 1.0144f, -2.2118f, 1.3213f, -0.3458f, -0.6223f, 3.3251f,
      -1.7114f, -1.4878f,
      0.7084f, -0.9955f, -4.0932f, 2.0278f, -6.8458f, 5.7611f, -0.7183f, -5.0916f,
      4.0746f, 5.1726f,
      -0.2490f, 0.6075f, 0.1680f, 3.1275f, 4.7316f, -7.0229f, 2.4051f, 1.4284f,
      -3.0204f, -2.1757f,
      -3.7167f, 2.5645f, -0.9343f, 0.1587f, -0.2478f, -1.4373f, 0.2454f, -1.1365f,
      1.3674f, 3.1366f,
      -0.1489f, -0.9223f, 0.1215f, 0.3527f, 0.7283f, -0.7326f, 0.2586f, -0.7780f,
      0.5384f, 0.5824f,
      0.5819f, 0.0882f, -1.1666f, -1.8632f, 1.9912f, -1.3875f, 1.7966f, 3.3693f,
      -1.2368f, -2.1730f,
      -2.4313f, 0.3267f, -1.1114f, 1.9005f, -0.8745f, -0.7353f, -0.0859f, -2.0432f,
      1.6968f, 3.3577f,
      -2.4888f, 1.1940f, -5.1665f, 5.2095f, -2.2262f, -2.6014f, -0.3177f, -4.2252f,
      3.9821f, 6.6401f,
      -0.9992f, -0.0238f, -0.9852f, 1.5757f, -1.7247f, 0.6126f, -1.1643f, -2.0062f,
      1.8744f, 2.8408f,
      -0.9175f, 3.0243f, 2.9183f, -5.7390f, 3.0473f, -1.5971f, 0.5185f, 6.8529f,
      -4.4063f, -3.7014f,
      0.5587f, -1.2656f, 1.4929f, -1.4041f, 2.9816f, -2.6508f, 3.6969f, -0.1655f,
      -3.3543f, 0.1102f,
      1.0491f, -1.6032f, 4.5694f, -2.7540f, 7.6610f, -5.8144f, 2.0148f, 5.1709f,
      -5.0915f, -5.2022f,
      -1.7095f, -0.0718f, -1.9150f, 4.5710f, -3.3039f, 1.1465f, -0.0588f, -6.3460f,
      3.2328f, 4.4547f,
      -0.6396f, -1.1049f, 0.6717f, 1.0309f, 1.3428f, -1.7577f, -1.3764f, 0.0852f,
      1.2269f, 0.5212f,
      -0.4662f, -3.3393f, -2.2198f, 3.9423f, -0.6190f, -0.5915f, 2.9082f, -7.3044f,
      2.4547f, 5.2350f,
      -3.4249f, 1.1559f, -1.4834f, 2.2900f, -0.3235f, -0.4688f, 1.2756f, -3.9352f,
      2.4816f, 2.4329f,
      -3.4417f, -1.1064f, 1.1508f, 1.2264f, 4.5061f, -4.3013f, 0.7740f, -1.2991f,
      1.2166f, 1.2747f,
      0.8694f, -0.4399f, -1.7460f, 1.3607f, -0.4675f, 0.4738f, 0.8979f, -0.9408f,
      0.9030f, -0.9105f,
      -3.6571f, 0.7072f, 0.6669f, 0.7547f, 4.1650f, -4.8541f, 1.9653f, 0.7386f,
      -0.3924f, -0.0940f,
      -2.9875f, 0.2738f, 3.7445f, -2.2234f, 1.2804f, -1.4137f, 1.7017f, -0.4212f,
      -3.1080f, 3.1534f,
      -0.4814f, -1.4524f, 0.6254f, 1.4415f, 0.2565f, 0.7982f, -0.3457f, -2.5523f,
      1.6931f, 0.0170f), Array(64, 10))
//    expectGrad(14 - 4) =
    expectGrad(14) =
      Tensor[Float](Array[Float](-102.0154f, -2.0866f, -78.0799f,
      190.8661f, 33.5124f, -53.4546f,
      -18.0541f, 52.2774f, -15.8880f, -7.4727f), Array(1, 10))
//    expectGrad(13 - 4) =
    expectGrad(13) =
      Tensor[Float](Array[Float](8.0734f, 8.6608f, -22.1519f,
      -103.3919f, -39.7637f, 128.0594f,
      -29.7084f, 55.4173f, 14.1729f, -19.3239f), Array(1, 10))
//    expectGrad(12 - 4) =
    expectGrad(12) =
      Tensor[Float](Array[Float](-102.0217f, -2.0389f, -78.0332f,
      190.8739f, 33.5608f, -53.3832f,
      -18.0584f, 52.2859f, -15.7794f, -7.4057f), Array(10))
//    expectGrad(11 - 4) =
    expectGrad(11) =
      Tensor[Float](Array[Float](2.4003f, -3.1011f, 4.0827f, 0.9945f,
      1.6197f, 4.0893f, 1.0422f, -6.7701f, -1.8708f, -2.9517f,
      1.5749f, -0.0829f, 1.1692f, -1.4690f, 0.6115f, 0.1689f, 0.9442f,
      1.4587f, -0.0923f, -0.8213f,
      1.4739f, -2.3674f, 2.8538f, 1.1234f, 1.0915f, 3.1195f, 0.5770f,
      -5.5633f, -1.4419f, -2.0806f,
      -10.1726f, 5.8948f, -11.6371f, 3.6610f, -5.1433f, -7.9841f, -5.3870f,
      6.7196f, 3.7261f, 8.3765f,
      -0.3860f, 1.0025f, -1.0810f, -0.6922f, -0.4246f, -1.3201f, -0.1100f,
      2.6304f, 0.5997f, 0.7783f,
      4.0834f, -1.6631f, 4.1273f, -2.2208f, 1.8997f, 2.3060f, 2.2578f,
      -0.5874f, -1.0922f, -2.9576f,
      1.1177f, -0.5556f, 1.2305f, -0.5313f, 0.5827f, 0.7577f, 0.6115f,
      -0.4469f, -0.3447f, -0.8855f,
      -2.6446f, 1.6102f, -3.0874f, 0.8600f, -1.3581f, -2.1810f, -1.3864f,
      1.9967f, 1.0186f, 2.2203f,
      2.6249f, -0.5141f, 2.2167f, -2.0290f, 1.0899f, 0.7656f, 1.5283f,
      1.3067f, -0.3823f, -1.5783f,
      -0.0717f, -0.2233f, 0.1252f, 0.3035f, 0.0310f, 0.2783f, -0.0776f,
      -0.7444f, -0.1201f, -0.1000f), Array(10, 10))

    expectGrad(10) = Tensor[Float](Array[Float](0.3589f, 0.4846f, -0.6027f,
      9.1532f, -3.3322f, 0.7664f, 2.8848f, -6.0095f,
      -2.2415f, 6.5400f), Array(10))
    expectGrad(9) = Tensor[Float](Array[Float](-1.7866f, -3.2862f, 1.7699f,
      0.2298f, 2.0777f, -1.4149f, 0.9542f, 0.0221f, 1.8769f, -0.4430f,
      -0.6441f, -0.7191f, 0.7529f, -0.2866f, 0.7656f, -0.4070f, 0.2325f, 0.2394f,
      0.3241f, -0.2578f,
      0.6840f, 0.6616f, -0.8314f, 0.3911f, -0.8294f, 0.4195f, -0.2283f, -0.3129f,
      -0.2573f, 0.3031f,
      -6.4642f, -1.4962f, 8.9090f, -7.4525f, 7.8305f, -2.7669f, 0.7754f, 5.3710f,
      -0.8061f, -3.9001f,
      1.1580f, -1.9868f, -2.1267f, 3.1284f, -1.4594f, -0.0207f, 0.4580f, -2.1208f,
      1.7725f, 1.1973f,
      -0.1231f, 0.7463f, 0.3599f, -0.7667f, 0.1794f, 0.1171f, -0.1847f, 0.5071f,
      -0.5829f, -0.2525f,
      -1.9889f, -0.3771f, 2.7664f, -2.3662f, 2.4143f, -0.8295f, 0.2176f, 1.6964f,
      -0.3128f, -1.2203f,
      2.6044f, -2.4214f, -4.3527f, 5.4365f, -3.2773f, 0.4465f, 0.4439f, -3.7094f,
      2.5734f, 2.2560f,
      1.5517f, 0.3009f, -2.1554f, 1.8380f, -1.8780f, 0.6440f, -0.1707f, -1.3124f,
      0.2364f, 0.9456f,
      -2.7177f, 2.8929f, 4.6267f, -5.9531f, 3.4265f, -0.3818f, -0.5521f, 4.0397f,
      -2.9541f, -2.4271f), Array(10, 10))
    expectGrad(8) = Tensor[Float](Array[Float](0.0000f, -0.0000f, 0.0000f, 0.0000f,
      0.0000f, -0.0000f, 0.0000f, -0.0000f, 0.0000f, 0.0000f), Array(10))
    expectGrad(7) = Tensor[Float](Array[Float](0.0080f, 0.0061f, -0.0098f, 0.0001f,
      0.0032f, -0.0079f, -0.0049f, 0.0140f, -0.0031f, -0.0058f,
      0.0077f, 0.0045f, -0.0067f, -0.0003f, 0.0014f, -0.0057f, 0.0027f, 0.0072f,
      -0.0102f, -0.0008f,
      -0.0026f, 0.0046f, -0.0021f, 0.0001f, -0.0005f, -0.0017f, 0.0001f, -0.0004f,
      0.0001f, 0.0024f,
      -0.0030f, 0.0250f, -0.0207f, 0.0004f, 0.0013f, -0.0149f, -0.0011f, 0.0147f,
      -0.0054f, 0.0038f,
      -0.0033f, 0.0138f, -0.0095f, -0.0004f, 0.0008f, -0.0062f, -0.0077f, 0.0119f,
      0.0026f, -0.0021f,
      -0.0041f, -0.0010f, 0.0002f, 0.0066f, 0.0017f, -0.0034f, 0.0042f, -0.0036f,
      -0.0030f, 0.0024f,
      -0.0051f, 0.0016f, -0.0014f, -0.0011f, -0.0019f, 0.0018f, -0.0036f, 0.0024f,
      0.0046f, 0.0027f,
      -0.0015f, 0.0024f, -0.0063f, 0.0029f, 0.0028f, -0.0053f, -0.0026f, 0.0050f,
      0.0022f, 0.0004f,
      0.0010f, 0.0021f, -0.0040f, -0.0018f, 0.0010f, -0.0018f, -0.0034f, 0.0044f,
      0.0031f, -0.0005f,
      -0.0092f, -0.0015f, 0.0076f, -0.0002f, -0.0057f, 0.0073f, 0.0027f, -0.0060f,
      0.0002f, 0.0047f), Array(10, 10))
    expectGrad(6) = Tensor[Float](Array[Float](0.0013f, -0.0066f, -0.0002f, 0.0076f,
      0.0070f, -0.0041f, -0.0065f, -0.0103f,
      0.0085f, 0.0022f), Array(10))
    expectGrad(5) = Tensor[Float](Array[Float]( 0.0001f, 0.0019f, -0.0018f, -0.0019f,
      0.0014f, -0.0010f, -0.0005f, 0.0061f, 0.0001f, -0.0043f,
      -0.0042f, -0.0064f, -0.0023f, 0.0140f, -0.0014f, -0.0049f, 0.0085f, -0.0165f,
      0.0007f, 0.0123f,
      -0.0027f, -0.0011f, 0.0032f, 0.0060f, 0.0001f, -0.0058f, 0.0046f, -0.0072f,
      -0.0046f, 0.0075f,
      -0.0013f, 0.0030f, -0.0026f, -0.0060f, 0.0096f, -0.0084f, -0.0067f, 0.0242f,
      0.0040f, -0.0156f,
      0.0036f, 0.0054f, -0.0015f, -0.0130f, 0.0047f, 0.0016f, -0.0077f, 0.0234f,
      0.0015f, -0.0181f,
      0.0027f, 0.0032f, -0.0101f, 0.0100f, -0.0001f, -0.0063f, 0.0020f, -0.0064f,
      0.0037f, 0.0012f,
      -0.0010f, -0.0001f, -0.0065f, 0.0028f, -0.0063f, 0.0030f, 0.0005f, -0.0007f,
      0.0037f, 0.0047f,
      0.0030f, -0.0022f, -0.0121f, 0.0127f, -0.0072f, -0.0008f, 0.0025f, -0.0148f,
      0.0074f, 0.0115f,
      0.0019f, -0.0024f, 0.0109f, -0.0038f, 0.0068f, -0.0026f, 0.0011f, -0.0021f,
      -0.0068f, -0.0031f,
      0.0000f, -0.0011f, 0.0047f, -0.0051f, 0.0001f, 0.0023f, 0.0001f, 0.0037f,
      -0.0035f, -0.0013f), Array(10, 10))

//    val vb = Tensor[Float](Array[Float](0.3589f, 0.4846f, -0.6027f,
//      9.1532f, -3.3322f, 0.7664f, 2.8848f, -6.0095f,
//      -2.2415f, 6.5400f), Array(10))
//    val vw = Tensor[Float](Array[Float](-1.7866f, -3.2862f, 1.7699f,
//      0.2298f, 2.0777f, -1.4149f, 0.9542f, 0.0221f, 1.8769f, -0.4430f,
//      -0.6441f, -0.7191f, 0.7529f, -0.2866f, 0.7656f, -0.4070f, 0.2325f, 0.2394f,
//      0.3241f, -0.2578f,
//      0.6840f, 0.6616f, -0.8314f, 0.3911f, -0.8294f, 0.4195f, -0.2283f, -0.3129f,
//      -0.2573f, 0.3031f,
//      -6.4642f, -1.4962f, 8.9090f, -7.4525f, 7.8305f, -2.7669f, 0.7754f, 5.3710f,
//      -0.8061f, -3.9001f,
//      1.1580f, -1.9868f, -2.1267f, 3.1284f, -1.4594f, -0.0207f, 0.4580f, -2.1208f,
//      1.7725f, 1.1973f,
//      -0.1231f, 0.7463f, 0.3599f, -0.7667f, 0.1794f, 0.1171f, -0.1847f, 0.5071f,
//      -0.5829f, -0.2525f,
//      -1.9889f, -0.3771f, 2.7664f, -2.3662f, 2.4143f, -0.8295f, 0.2176f, 1.6964f,
//      -0.3128f, -1.2203f,
//      2.6044f, -2.4214f, -4.3527f, 5.4365f, -3.2773f, 0.4465f, 0.4439f, -3.7094f,
//      2.5734f, 2.2560f,
//      1.5517f, 0.3009f, -2.1554f, 1.8380f, -1.8780f, 0.6440f, -0.1707f, -1.3124f,
//      0.2364f, 0.9456f,
//      -2.7177f, 2.8929f, 4.6267f, -5.9531f, 3.4265f, -0.3818f, -0.5521f, 4.0397f,
//      -2.9541f, -2.4271f), Array(10, 10))
//    val kb = Tensor[Float](Array[Float](0.0000f, -0.0000f, 0.0000f, 0.0000f,
//      0.0000f, -0.0000f, 0.0000f, -0.0000f, 0.0000f, 0.0000f), Array(10))
//    val kw = Tensor[Float](Array[Float](0.0080f, 0.0061f, -0.0098f, 0.0001f,
//      0.0032f, -0.0079f, -0.0049f, 0.0140f, -0.0031f, -0.0058f,
//      0.0077f, 0.0045f, -0.0067f, -0.0003f, 0.0014f, -0.0057f, 0.0027f, 0.0072f,
//      -0.0102f, -0.0008f,
//      -0.0026f, 0.0046f, -0.0021f, 0.0001f, -0.0005f, -0.0017f, 0.0001f, -0.0004f,
//      0.0001f, 0.0024f,
//      -0.0030f, 0.0250f, -0.0207f, 0.0004f, 0.0013f, -0.0149f, -0.0011f, 0.0147f,
//      -0.0054f, 0.0038f,
//      -0.0033f, 0.0138f, -0.0095f, -0.0004f, 0.0008f, -0.0062f, -0.0077f, 0.0119f,
//      0.0026f, -0.0021f,
//      -0.0041f, -0.0010f, 0.0002f, 0.0066f, 0.0017f, -0.0034f, 0.0042f, -0.0036f,
//      -0.0030f, 0.0024f,
//      -0.0051f, 0.0016f, -0.0014f, -0.0011f, -0.0019f, 0.0018f, -0.0036f, 0.0024f,
//      0.0046f, 0.0027f,
//      -0.0015f, 0.0024f, -0.0063f, 0.0029f, 0.0028f, -0.0053f, -0.0026f, 0.0050f,
//      0.0022f, 0.0004f,
//      0.0010f, 0.0021f, -0.0040f, -0.0018f, 0.0010f, -0.0018f, -0.0034f, 0.0044f,
//      0.0031f, -0.0005f,
//      -0.0092f, -0.0015f, 0.0076f, -0.0002f, -0.0057f, 0.0073f, 0.0027f, -0.0060f,
//      0.0002f, 0.0047f), Array(10, 10))
//    val qb = Tensor[Float](Array[Float](0.0013f, -0.0066f, -0.0002f, 0.0076f,
//      0.0070f, -0.0041f, -0.0065f, -0.0103f,
//      0.0085f, 0.0022f), Array(10))
//    val qw = Tensor[Float](Array[Float]( 0.0001f, 0.0019f, -0.0018f, -0.0019f,
//      0.0014f, -0.0010f, -0.0005f, 0.0061f, 0.0001f, -0.0043f,
//      -0.0042f, -0.0064f, -0.0023f, 0.0140f, -0.0014f, -0.0049f, 0.0085f, -0.0165f,
//      0.0007f, 0.0123f,
//      -0.0027f, -0.0011f, 0.0032f, 0.0060f, 0.0001f, -0.0058f, 0.0046f, -0.0072f,
//      -0.0046f, 0.0075f,
//      -0.0013f, 0.0030f, -0.0026f, -0.0060f, 0.0096f, -0.0084f, -0.0067f, 0.0242f,
//      0.0040f, -0.0156f,
//      0.0036f, 0.0054f, -0.0015f, -0.0130f, 0.0047f, 0.0016f, -0.0077f, 0.0234f,
//      0.0015f, -0.0181f,
//      0.0027f, 0.0032f, -0.0101f, 0.0100f, -0.0001f, -0.0063f, 0.0020f, -0.0064f,
//      0.0037f, 0.0012f,
//      -0.0010f, -0.0001f, -0.0065f, 0.0028f, -0.0063f, 0.0030f, 0.0005f, -0.0007f,
//      0.0037f, 0.0047f,
//      0.0030f, -0.0022f, -0.0121f, 0.0127f, -0.0072f, -0.0008f, 0.0025f, -0.0148f,
//      0.0074f, 0.0115f,
//      0.0019f, -0.0024f, 0.0109f, -0.0038f, 0.0068f, -0.0026f, 0.0011f, -0.0021f,
//      -0.0068f, -0.0031f,
//      0.0000f, -0.0011f, 0.0047f, -0.0051f, 0.0001f, 0.0023f, 0.0001f, 0.0037f,
//      -0.0035f, -0.0013f), Array(10, 10))
//
//    expectGrad(6) = Tensor[Float](30)
//    expectGrad(6).narrow(1, 1, 10).copy(qb)
//    expectGrad(6).narrow(1, 11, 10).copy(kb)
//    expectGrad(6).narrow(1, 21, 10).copy(vb)
//    expectGrad(5) = Tensor[Float](30, 10)
//    expectGrad(5).narrow(1, 1, 10).copy(qw)
//    expectGrad(5).narrow(1, 11, 10).copy(kw)
//    expectGrad(5).narrow(1, 21, 10).copy(vw)

    expectGrad(4) = Tensor[Float](Array[Float](-101.5845f, -2.2855f, -78.2359f,
      190.9308f, 33.9887f, -53.8646f,
      -17.9227f, 52.5855f, -14.8049f, -7.6130f), Array(1, 10))
    expectGrad(3) = Tensor[Float](Array[Float](7.9889f, 8.6883f, -22.1974f,
      -103.4091f, -39.6225f, 128.1670f,
      -29.7590f, 55.7228f, 13.7011f, -19.2770f), Array(1, 10))
    expectGrad(2) = Tensor[Float](Array[Float](-4011.8169f, 1755.9337f,
      -4640.7207f, 8251.9453f, 535.7934f, -2597.7024f,
      -1519.1013f, 2743.7759f, 761.9067f, -1280.0115f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f), Array(2, 10))
    expectGrad(0) = Tensor[Float](Array[Float](-876.0627f, -532.3563f, -1951.8788f,
      3806.3535f, -230.7501f, -795.2228f,
      -905.6954f, 1041.7300f, 731.2013f, -287.3188f,
      -504.9723f, 2256.0664f, -2922.6753f, 781.3654f, -107.3239f, -3153.2095f,
      1514.8699f, 521.5922f, 820.7585f, 793.5289f,
      -534.9454f, 1653.0405f, -566.9773f, 1131.5408f, 155.1542f, -1590.5498f,
      -920.5416f, 1007.6495f, 55.7195f, -390.0904f,
      -1300.1160f, -746.0009f, -1037.4202f, 414.5659f, 3042.9480f, 1262.3949f,
      -2045.1296f, -40.0371f, 948.7982f, -500.0036f,
      336.6606f, -831.4297f, 454.6042f, 1655.4795f, -3231.9053f, 1074.3910f,
      1437.2823f, -30.0840f, -631.9481f, -233.0501f,
      -1132.3813f, -43.3866f, 1383.6260f, 462.6400f, 907.6707f, 604.4938f,
      -599.8870f, 242.9250f, -1162.6229f, -663.0775f), Array(6, 10))
    expectGrad(1) = Tensor[Float](Array[Float](0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      -385.6432f, -1021.5302f, 1211.8705f, 612.4370f, -635.6130f, -43.0432f,
      -379.6112f, 416.1957f, -296.2715f, 521.2089f,
      27.2902f, -1129.9083f, -440.0944f, 1277.2582f, 1511.2589f, 915.7172f,
      -1234.0095f, 1025.8958f, -1006.9253f, -946.4828f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      -1590.9420f, 1806.2361f, -1804.8018f, 2566.0967f, 1146.3241f, 607.3188f,
      -2168.6509f, 1267.2148f, 292.6887f, -2121.4844f,
      336.6606f, -831.4297f, 454.6042f, 1655.4795f, -3231.9053f, 1074.3910f,
      1437.2823f, -30.0840f, -631.9481f, -233.0501f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      -2186.9043f, 3368.6128f, -2483.6807f, 1741.8201f, 753.9789f, -3819.4248f,
      -1388.7861f, 1048.4204f, 3380.5061f, -414.5416f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      97.5968f, 779.4239f, -1467.5935f, 310.2924f, 121.9856f, -235.4992f,
      117.1835f, -222.6989f, 133.3077f, 366.0017f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f, 0.0000f,
      0.0000f, 0.0000f, 0.0000f, 0.0000f,
      -31.8589f, -1360.4489f, 24.6785f, 1090.4601f, 166.2095f, -755.0047f,
      1042.6797f, -398.7556f, -427.8389f, 649.8792f,
      -278.0162f, 144.9777f, -135.7040f, -1001.8986f, 703.5551f, -342.1578f,
      1054.8109f, -362.4127f, -681.6121f, 898.4576f), Array(20, 10))

    var i = expectGrad.size - 1
    while (i >= 0) {
      TestUtils.conditionFailTest(expectGrad(i).almostEqual(grads(i), 3.5) == true)
      i -= 1
    }
  }

  "Bert " should "save/load be able to work" in {
    // Build a small BERT with dropout disabled so forward/backward are deterministic,
    // which lets us compare the original and the reloaded module exactly.
    val model = BERT[Float](vocab = 30000,
      hiddenSize = 10,
      nBlock = 5,
      nHead = 2,
      intermediateSize = 64,
      hiddenPDrop = 0,
      attnPDrop = 0,
      maxPositionLen = 11,
      outputAllBlock = false)
    val buildShape = Shape(List(Shape(1, 11), Shape(1, 11), Shape(1, 11), Shape(1, 1, 1, 11)))
    model.build(buildShape)

    // Inputs: token ids, segment ids, position ids, and an all-ones attention mask.
    val tokenIds = Tensor[Float](Array[Float](2040f, 2001, 3958, 27227, 1029, 3958, 103,
      2001, 1037, 13997, 11510), Array(1, 11))
    val sentenceIds = Tensor[Float](Array[Float](0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1), Array(1, 11))
    val posIds = Tensor[Float](Array[Float](0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), Array(1, 11))
    val attentionMask = Tensor[Float](1, 1, 1, 11).fill(1.0f)

    // Record reference outputs/gradients from the original model (cloned, since
    // forward/backward may reuse internal buffers on later calls).
    val fwdInput = T(tokenIds.clone(), sentenceIds.clone(), posIds.clone(), attentionMask.clone())
    val fwdOutput = model.forward(fwdInput).toTable
    val expectedOutputs = T(fwdOutput[Tensor[Float]](1).clone(), fwdOutput[Tensor[Float]](2).clone())
    val gradOut1 = Tensor[Float](1, 11, 10).rand()
    val gradOutPooled = Tensor[Float](1, 10)
    val gradTable = T(gradOut1, gradOutPooled)
    val gradIn = model.backward(fwdInput, T(gradOut1.clone(), gradOutPooled.clone())).toTable
    val clonedGrads = (1 to gradIn.length()).map(i => gradIn[Tensor[Float]](i).clone())
    val expectedGradInput = T.array(clonedGrads.toArray)

    // Round-trip the module through serialization and rerun the same pass.
    val savedFile = ZooSpecHelper.createTmpFile()
    val modelPath = savedFile.getAbsolutePath
    model.saveModule(modelPath, overWrite = true)
    val restored = BERT[Float](modelPath, null, 11, 0, 0, false)
    val restoredOutput = restored.forward(T(tokenIds, sentenceIds, posIds, attentionMask)).toTable
    val restoredGradIn = restored.backward(T(tokenIds, sentenceIds, posIds, attentionMask),
      gradTable).toTable
    // The reloaded module must reproduce the original forward outputs...
    for (i <- 1 to restoredOutput.length()) {
      TestUtils.conditionFailTest(
        restoredOutput[Tensor[Float]](i).almostEqual(expectedOutputs[Tensor[Float]](i), 1e-8))
    }
    // ...and the original input gradients.
    for (i <- 1 to gradIn.length()) {
      TestUtils.conditionFailTest(
        expectedGradInput[Tensor[Float]](i).almostEqual(restoredGradIn[Tensor[Float]](i), 1e-5))
    }
  }

//   TODO: uncomment this ut after we have put zoo model in a public place
//  "Bert with pretrained model " should "be able to work" in {
//    // TODO: put zoo model in a public place
//    val layer = BERT[Float]("/tmp/zoo-bert-splitprojection.model", null,
//      inputSeqLen = 11, hiddenPDrop = 0.0, attnPDrop = 0.0, true)
//
//    val inputIds = Tensor[Float](Array[Float](2040f, 2001, 3958, 27227, 1029, 3958, 103,
//      2001, 1037, 13997, 11510), Array(1, 11))
//    val segmentIds = Tensor[Float](Array[Float](0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1), Array(1, 11))
//    val positionIds = Tensor[Float](Array[Float](0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), Array(1, 11))
//    val masks = Tensor[Float](1, 1, 1, 11).fill(1.0f)
//
//    val input = T(inputIds, segmentIds, positionIds, masks)
//    val output = layer.forward(input).toTable
//
//    val expectPoolOutput = Tensor[Float](Array[Float](-0.7162f, -0.1396f, 0.7603f,
//      0.3641f, -0.6005f, -0.2462f, 0.7046f, 0.0890f,
//      0.8521f, -0.9968f, 0.5658f, -0.6746f, 0.9748f, -0.6179f, 0.9423f, -0.3352f,
//      0.0830f, -0.4201f, 0.2724f, -0.5666f, 0.7584f, -0.0920f, 0.8036f, 0.2927f,
//      0.0453f, -0.8674f, -0.4508f, 0.9286f, 0.9351f, 0.8020f, -0.6272f, 0.0372f,
//      -0.9862f, -0.1660f, 0.0781f, -0.9841f, -0.1545f, -0.6641f, -0.2398f, 0.0144f,
//      -0.9033f, 0.1099f, 0.9771f, -0.8061f, 0.0059f, -0.2065f, -0.9722f, 0.0251f,
//      -0.9021f, -0.9456f, -0.6896f, -0.9564f, 0.0519f, 0.1854f, 0.1896f, 0.6364f,
//      -0.1458f, 0.0746f, -0.0153f, -0.3355f, -0.4169f, 0.0664f, 0.6940f, -0.8770f,
//      -0.9073f, -0.7581f, -0.2571f, -0.2412f, -0.0101f, -0.0291f, 0.7521f, 0.1513f,
//      0.6677f, -0.8569f, -0.8606f, 0.2286f, -0.3572f, 0.9978f, 0.3681f, -0.9787f,
//      -0.0825f, -0.7081f, 0.3884f, 0.8268f, -0.6925f, -0.9885f, -0.0306f, -0.1499f,
//      -0.9875f, 0.1070f, 0.1838f, -0.0348f, 0.1364f, 0.4469f, -0.2244f, -0.0421f,
//      -0.0109f, 0.3887f, -0.1453f, 0.1406f, -0.0316f, -0.0859f, -0.0602f, -0.2716f,
//      0.0345f, -0.0536f, -0.1663f, -0.2223f, -0.5041f, 0.3877f, 0.3318f, -0.1905f,
//      0.0250f, -0.9412f, 0.2981f, -0.1303f, -0.9602f, -0.4054f, -0.9871f, 0.7988f,
//      -0.0876f, -0.4186f, 0.9571f, 0.7023f, 0.1317f, -0.0598f, 0.9564f, -0.9982f,
//      0.6961f, 0.0997f, 0.6105f, -0.0848f, -0.9721f, -0.9648f, 0.2606f, 0.8880f,
//      0.0675f, 0.9830f, -0.0328f, 0.9508f, 0.7032f, 0.3969f, -0.9032f, -0.4663f,
//      -0.1569f, 0.1598f, -0.6167f, 0.3050f, 0.3721f, -0.8162f, -0.0270f, -0.1218f,
//      0.7294f, -0.9266f, -0.2046f, 0.9405f, 0.6775f, 0.8277f, 0.7095f, -0.0222f,
//      -0.1586f, 0.8070f, -0.1977f, 0.2968f, 0.2671f, 0.0159f, -0.3666f, 0.1654f,
//      -0.6356f, 0.8115f, 0.1022f, -0.0277f, 0.4317f, -0.9800f, -0.2187f, 0.2425f,
//      0.9773f, 0.6488f, 0.1156f, -0.8433f, -0.1829f, -0.5950f, -0.9516f, 0.9792f,
//      -0.0946f, 0.1392f, 0.4140f, -0.5860f, -0.7811f, -0.1265f, 0.5148f, 0.5665f,
//      -0.7989f, 0.0579f, -0.2782f, -0.0707f, 0.1247f, 0.3033f, -0.1486f, -0.3969f,
//      -0.1184f, 0.9160f, 0.7994f, 0.6171f, -0.5199f, -0.1578f, -0.9006f, -0.3371f,
//      0.1465f, 0.0052f, -0.0266f, 0.9798f, -0.0686f, 0.0602f, -0.8204f, -0.9792f,
//      0.1078f, -0.7781f, -0.0333f, -0.3414f, 0.0418f, 0.3171f, -0.7073f, 0.1727f,
//      -0.9042f, -0.8144f, 0.0733f, -0.0641f, 0.2665f, -0.0541f, 0.3952f, -0.6823f,
//      -0.4333f, 0.5374f, 0.9017f, 0.7004f, -0.7999f, 0.6353f, -0.1725f, 0.7958f,
//      -0.2901f, 0.9769f, -0.6294f, -0.1802f, -0.9378f, 0.2437f, -0.6200f, 0.6196f,
//      -0.1076f, -0.7601f, -0.7550f, 0.4579f, -0.0830f, 0.6436f, -0.0972f, 0.8839f,
//      -0.6791f, -0.9578f, -0.3163f, 0.4027f, -0.9870f, -0.6541f, 0.1541f, -0.5528f,
//      -0.2818f, 0.0063f, -0.9724f, 0.6324f, 0.1280f, 0.9215f, 0.1273f, -0.5899f,
//      0.5715f, -0.9524f, 0.0339f, -0.2239f, 0.8126f, -0.0554f, -0.9408f, 0.3213f,
//      0.4896f, 0.0629f, 0.7366f, 0.9008f, 0.9069f, 0.9719f, 0.8679f, 0.5541f,
//      -0.2049f, 0.1948f, 0.9966f, 0.6645f, -0.9866f, -0.9136f, -0.2357f, 0.4273f,
//      -0.9980f, -0.0674f, -0.1421f, -0.9326f, -0.6956f, 0.9571f, 0.9187f, -0.9966f,
//      0.7822f, 0.8609f, -0.4010f, -0.8146f, -0.0580f, 0.9602f, 0.1276f, 0.2033f,
//      -0.1566f, 0.4025f, -0.0706f, -0.7124f, 0.5992f, 0.5017f, -0.5854f, 0.0679f,
//      -0.0444f, -0.9526f, -0.3486f, -0.0046f, -0.2216f, -0.9666f, -0.1915f, -0.8234f,
//      0.0790f, -0.0616f, 0.0841f, -0.7505f, 0.0087f, -0.6945f, 0.4681f, 0.4427f,
//      -0.7811f, -0.4501f, 0.6757f, -0.7121f, 0.5318f, -0.9615f, 0.9654f, -0.1700f,
//      -0.9108f, 0.9955f, -0.0542f, -0.9140f, 0.0220f, -0.0412f, -0.3037f, 0.9941f,
//      -0.1090f, -0.9820f, -0.3511f, -0.3380f, -0.0779f, 0.0112f, 0.9919f, -0.0013f,
//      0.6381f, 0.5495f, 0.9873f, -0.9844f, -0.5295f, -0.7156f, -0.9563f, 0.9685f,
//      0.9571f, 0.0137f, -0.4769f, 0.0336f, 0.3625f, 0.1125f, -0.8594f, 0.2999f,
//      0.5331f, -0.1464f, 0.8701f, -0.5283f, -0.2819f, 0.1792f, 0.6559f, 0.7553f,
//      -0.7151f, 0.1020f, -0.0604f, -0.0285f, -0.2009f, -0.0457f, -0.9241f, -0.2874f,
//      0.9940f, 0.3557f, -0.5637f, 0.3754f, -0.1321f, 0.0472f, 0.0158f, 0.2659f,
//      -0.0973f, -0.7388f, -0.6294f, -0.7506f, -0.9889f, 0.7080f, 0.0880f, -0.0248f,
//      0.8815f, 0.3149f, 0.0668f, -0.4228f, -0.8850f, 0.0654f, 0.1866f, -0.8231f,
//      0.9515f, -0.1010f, 0.3055f, 0.6091f, 0.8471f, -0.2643f, -0.2524f, 0.0051f,
//      -0.9161f, 0.1817f, -0.9544f, 0.9664f, -0.9283f, 0.0641f, 0.2203f, -0.4566f,
//      0.9939f, -0.3684f, 0.3610f, -0.3358f, 0.7058f, 0.0779f, -0.5924f, -0.3633f,
//      0.0304f, 0.8327f, 0.0060f, 0.1495f, -0.9548f, -0.8065f, -0.7363f, -0.8592f,
//      -0.9867f, 0.7629f, 0.2969f, 0.1757f, 0.2586f, -0.1554f, -0.5450f, -0.3296f,
//      0.0075f, -0.9652f, 0.8086f, -0.1231f, 0.1124f, -0.2777f, 0.4091f, -0.7615f,
//      0.7235f, 0.6044f, 0.2692f, -0.2067f, -0.7993f, 0.1972f, -0.2400f, 0.8271f,
//      -0.0676f, 0.9969f, 0.2129f, -0.3808f, 0.6144f, 0.4426f, -0.1054f, 0.1481f,
//      -0.7204f, 0.1953f, 0.5176f, 0.9203f, -0.5080f, -0.0119f, 0.0775f, -0.2754f,
//      -0.5738f, 0.8438f, -0.3832f, 0.0354f, -0.1252f, 0.0627f, 0.8787f, 0.0781f,
//      0.0409f, -0.4946f, -0.0855f, -0.1822f, -0.2237f, 0.9882f, 0.1373f, -0.2795f,
//      -0.9899f, 0.7308f, -0.7784f, 0.7151f, 0.9118f, -0.8336f, -0.0133f, -0.2039f,
//      -0.1394f, 0.2967f, -0.1437f, -0.1720f, 0.1148f, 0.1680f, 0.9750f, -0.3740f,
//      -0.9799f, -0.6423f, 0.0846f, -0.9437f, 0.2282f, -0.2336f, 0.0390f, -0.2100f,
//      0.5036f, 0.6197f, -0.0862f, -0.9763f, -0.0167f, -0.0524f, 0.9740f, 0.1433f,
//      -0.3203f, -0.8910f, -0.8981f, -0.3564f, 0.8119f, -0.9449f, 0.9762f, -0.9524f,
//      -0.3477f, 0.9853f, 0.2716f, -0.9356f, -0.0130f, -0.3620f, 0.1561f, -0.1850f,
//      0.3515f, -0.9542f, -0.1533f, -0.1714f, 0.2145f, 0.0633f, 0.2303f, 0.7615f,
//      0.1269f, -0.2645f, -0.3332f, -0.0549f, 0.3007f, 0.4099f, -0.0848f, -0.0151f,
//      -0.1012f, 0.0073f, -0.9460f, -0.1528f, -0.1057f, -0.2738f, 0.6139f, -0.9966f,
//      -0.6367f, -0.7447f, -0.1497f, 0.8077f, -0.1831f, -0.6210f, -0.8092f, 0.8853f,
//      0.8814f, 0.6907f, -0.1110f, 0.8732f, -0.7302f, 0.0512f, -0.1806f, 0.2345f,
//      0.7230f, 0.6315f, -0.1467f, 0.9984f, 0.1652f, -0.0225f, -0.7866f, 0.1438f,
//      -0.1376f, 0.9274f, -0.5747f, -0.9553f, 0.1266f, -0.1419f, -0.7931f, 0.1151f,
//      0.0198f, -0.3491f, 0.4269f, 0.8990f, 0.7312f, -0.3052f, 0.2512f, -0.0957f,
//      0.0593f, 0.1259f, -0.8737f, 0.9860f, 0.2696f, 0.4766f, 0.3463f, -0.0315f,
//      0.9666f, 0.0141f, 0.6121f, -0.0339f, 0.9844f, 0.1921f, -0.9012f, 0.1097f,
//      -0.9560f, -0.2892f, -0.8855f, 0.2145f, 0.0726f, 0.8682f, -0.1482f, 0.9529f,
//      0.8500f, -0.0611f, 0.7121f, 0.8382f, -0.0007f, -0.9559f, -0.9868f, -0.9916f,
//      -0.2818f, -0.2869f, 0.0582f, 0.2519f, 0.0024f, 0.1225f, 0.2525f, -0.9693f,
//      0.9053f, 0.2988f, -0.6687f, 0.9702f, -0.5648f, -0.0998f, 0.3368f, -0.9773f,
//      -0.9023f, -0.2172f, -0.1050f, 0.5856f, 0.2271f, 0.8620f, -0.0993f, -0.2223f,
//      -0.2049f, 0.4892f, -0.5222f, -0.9908f, 0.3641f, 0.8315f, -0.6809f, 0.9588f,
//      -0.4072f, -0.3041f, 0.7473f, 0.8007f, 0.7079f, 0.7385f, 0.4648f, 0.0785f,
//      0.6499f, 0.9205f, 0.7790f, 0.9819f, 0.5908f, 0.5459f, 0.9206f, 0.2609f,
//      0.5047f, -0.9116f, 0.2087f, -0.2924f, -0.0403f, 0.1722f, -0.0398f, -0.8534f,
//      0.2932f, -0.0467f, 0.4758f, -0.2525f, 0.0883f, -0.2131f, 0.0019f, -0.5476f,
//      0.0264f, 0.4317f, 0.1654f, 0.9330f, -0.2348f, -0.1628f, 0.1235f, -0.0188f,
//      0.7218f, -0.9457f, 0.6322f, 0.0025f, 0.7549f, -0.5926f, -0.1748f, 0.6003f,
//      -0.2664f, -0.0785f, -0.0902f, -0.5837f, 0.6162f, 0.2644f, -0.1987f, -0.2380f,
//      0.0994f, 0.1211f, 0.3593f, 0.2705f, 0.8306f, 0.3293f, 0.0518f, 0.2511f,
//      -0.0081f, -0.9695f, 0.0637f, 0.7831f, -0.2514f, 0.6340f, -0.3967f, 0.8890f,
//      -0.9158f, -0.2154f, -0.2064f, -0.6473f, -0.2937f, 0.6980f, 0.3257f, 0.9740f,
//      -0.7701f, 0.7724f, 0.5081f, 0.7332f, 0.3277f, 0.5873f, -0.4213f, 0.8806f), Array(1, 768))
//    TestUtils.conditionFailTest(output[Tensor[Float]](13).almostEqual(expectPoolOutput, 2e-4))
//
//    val gradPoolOutput = Tensor[Float](Array[Float](99.0f, 78.0f, 61.0f, 16.0f, 73.0f,
//      8.0f, 62.0f, 27.0f, 30.0f, 80.0f, 7.0f, 76.0f, 15.0f, 53.0f,
//      80.0f, 27.0f, 44.0f, 77.0f, 75.0f, 65.0f, 47.0f, 30.0f, 84.0f, 86.0f, 18.0f,
//      9.0f, 41.0f, 62.0f,
//      1.0f, 82.0f, 16.0f, 78.0f, 5.0f, 58.0f, 0.0f, 80.0f, 4.0f, 36.0f, 51.0f, 27.0f,
//      31.0f, 2.0f,
//      68.0f, 38.0f, 83.0f, 19.0f, 18.0f, 7.0f, 30.0f, 62.0f, 11.0f, 67.0f, 65.0f, 55.0f,
//      3.0f, 91.0f,
//      78.0f, 27.0f, 29.0f, 33.0f, 89.0f, 85.0f, 7.0f, 16.0f, 94.0f, 14.0f, 90.0f, 31.0f,
//      9.0f, 38.0f,
//      47.0f, 16.0f, 5.0f, 34.0f, 45.0f, 59.0f, 24.0f, 13.0f, 31.0f, 32.0f, 76.0f, 44.0f,
//      5.0f, 14.0f,
//      47.0f, 94.0f, 82.0f, 0.0f, 7.0f, 86.0f, 16.0f, 64.0f, 8.0f, 90.0f, 44.0f, 37.0f,
//      94.0f, 75.0f,
//      5.0f, 22.0f, 52.0f, 69.0f, 82.0f, 60.0f, 91.0f, 29.0f, 88.0f, 97.0f, 92.0f, 79.0f,
//      70.0f, 35.0f,
//      20.0f, 49.0f, 72.0f, 32.0f, 82.0f, 13.0f, 92.0f, 18.0f, 52.0f, 81.0f, 22.0f, 58.0f,
//      83.0f, 92.0f,
//      83.0f, 49.0f, 4.0f, 82.0f, 36.0f, 41.0f, 20.0f, 32.0f, 10.0f, 31.0f, 15.0f, 22.0f,
//      70.0f, 9.0f,
//      63.0f, 94.0f, 14.0f, 66.0f, 57.0f, 19.0f, 64.0f, 8.0f, 8.0f, 71.0f, 12.0f, 20.0f,
//      59.0f, 72.0f,
//      74.0f, 86.0f, 72.0f, 32.0f, 15.0f, 69.0f, 35.0f, 62.0f, 43.0f, 0.0f, 2.0f, 2.0f,
//      91.0f, 65.0f,
//      45.0f, 87.0f, 1.0f, 23.0f, 50.0f, 86.0f, 19.0f, 54.0f, 24.0f, 64.0f, 77.0f, 73.0f,
//      1.0f, 9.0f,
//      64.0f, 23.0f, 39.0f, 68.0f, 81.0f, 91.0f, 36.0f, 97.0f, 87.0f, 69.0f, 36.0f, 18.0f,
//      34.0f, 30.0f,
//      77.0f, 97.0f, 35.0f, 29.0f, 1.0f, 82.0f, 20.0f, 0.0f, 34.0f, 78.0f, 51.0f, 30.0f,
//      40.0f, 74.0f,
//      69.0f, 79.0f, 53.0f, 19.0f, 46.0f, 26.0f, 85.0f, 89.0f, 57.0f, 17.0f, 94.0f, 64.0f,
//      28.0f, 8.0f,
//      14.0f, 64.0f, 31.0f, 45.0f, 5.0f, 26.0f, 41.0f, 83.0f, 28.0f, 75.0f, 35.0f, 83.0f,
//      55.0f, 3.0f,
//      23.0f, 3.0f, 95.0f, 2.0f, 54.0f, 93.0f, 38.0f, 18.0f, 71.0f, 64.0f, 35.0f, 37.0f,
//      1.0f, 30.0f,
//      91.0f, 31.0f, 93.0f, 6.0f, 7.0f, 35.0f, 78.0f, 9.0f, 7.0f, 89.0f, 90.0f, 54.0f, 31.0f,
//      14.0f,
//      4.0f, 85.0f, 74.0f, 68.0f, 96.0f, 81.0f, 67.0f, 79.0f, 17.0f, 87.0f, 11.0f, 30.0f,
//      26.0f, 8.0f,
//      51.0f, 61.0f, 84.0f, 52.0f, 74.0f, 25.0f, 9.0f, 31.0f, 39.0f, 59.0f, 22.0f, 80.0f,
//      58.0f, 44.0f,
//      15.0f, 15.0f, 95.0f, 77.0f, 81.0f, 99.0f, 57.0f, 46.0f, 19.0f, 45.0f, 92.0f, 83.0f,
//      51.0f, 85.0f,
//      16.0f, 96.0f, 26.0f, 94.0f, 30.0f, 93.0f, 78.0f, 14.0f, 98.0f, 8.0f, 57.0f, 81.0f,
//      40.0f, 59.0f,
//      58.0f, 10.0f, 6.0f, 58.0f, 73.0f, 54.0f, 93.0f, 33.0f, 24.0f, 18.0f, 20.0f, 46.0f,
//      89.0f, 5.0f,
//      76.0f, 45.0f, 79.0f, 24.0f, 38.0f, 12.0f, 1.0f, 5.0f, 17.0f, 11.0f, 88.0f, 56.0f,
//      83.0f, 26.0f,
//      42.0f, 57.0f, 87.0f, 48.0f, 8.0f, 10.0f, 39.0f, 3.0f, 44.0f, 68.0f, 94.0f, 55.0f,
//      94.0f, 31.0f,
//      26.0f, 96.0f, 13.0f, 76.0f, 54.0f, 83.0f, 10.0f, 33.0f, 97.0f, 64.0f, 36.0f, 60.0f,
//      0.0f, 72.0f,
//      24.0f, 26.0f, 5.0f, 29.0f, 63.0f, 18.0f, 55.0f, 76.0f, 6.0f, 45.0f, 85.0f, 46.0f,
//      58.0f, 62.0f,
//      2.0f, 22.0f, 95.0f, 76.0f, 55.0f, 38.0f, 5.0f, 49.0f, 58.0f, 15.0f, 86.0f, 48.0f,
//      18.0f, 59.0f,
//      95.0f, 53.0f, 18.0f, 36.0f, 36.0f, 67.0f, 9.0f, 47.0f, 0.0f, 72.0f, 81.0f, 64.0f,
//      25.0f, 60.0f,
//      36.0f, 78.0f, 5.0f, 13.0f, 73.0f, 69.0f, 77.0f, 33.0f, 64.0f, 3.0f, 94.0f, 45.0f,
//      70.0f, 58.0f,
//      52.0f, 99.0f, 67.0f, 11.0f, 65.0f, 21.0f, 8.0f, 31.0f, 47.0f, 77.0f, 78.0f, 82.0f,
//      66.0f, 98.0f,
//      98.0f, 84.0f, 20.0f, 59.0f, 57.0f, 33.0f, 53.0f, 96.0f, 1.0f, 70.0f, 90.0f, 78.0f,
//      17.0f, 26.0f,
//      32.0f, 31.0f, 89.0f, 49.0f, 69.0f, 41.0f, 76.0f, 79.0f, 38.0f, 79.0f, 81.0f, 38.0f,
//      43.0f, 28.0f,
//      17.0f, 16.0f, 73.0f, 54.0f, 45.0f, 34.0f, 90.0f, 67.0f, 69.0f, 70.0f, 90.0f, 18.0f,
//      75.0f, 94.0f,
//      29.0f, 33.0f, 94.0f, 93.0f, 29.0f, 47.0f, 55.0f, 21.0f, 70.0f, 16.0f, 15.0f, 83.0f,
//      91.0f, 70.0f,
//      71.0f, 41.0f, 13.0f, 61.0f, 0.0f, 46.0f, 65.0f, 86.0f, 80.0f, 48.0f, 3.0f, 77.0f,
//      60.0f, 16.0f,
//      11.0f, 1.0f, 97.0f, 57.0f, 9.0f, 50.0f, 16.0f, 61.0f, 65.0f, 30.0f, 69.0f, 54.0f,
//      46.0f, 85.0f,
//      10.0f, 58.0f, 97.0f, 82.0f, 42.0f, 38.0f, 34.0f, 18.0f, 43.0f, 56.0f, 76.0f, 7.0f,
//      47.0f, 62.0f,
//      93.0f, 51.0f, 24.0f, 36.0f, 15.0f, 64.0f, 49.0f, 0.0f, 6.0f, 26.0f, 4.0f, 81.0f,
//      12.0f, 95.0f,
//      63.0f, 98.0f, 61.0f, 94.0f, 91.0f, 75.0f, 62.0f, 51.0f, 2.0f, 96.0f, 19.0f, 7.0f,
//      76.0f, 66.0f,
//      56.0f, 0.0f, 57.0f, 31.0f, 86.0f, 74.0f, 37.0f, 64.0f, 51.0f, 12.0f, 59.0f, 25.0f,
//      39.0f, 33.0f,
//      24.0f, 7.0f, 68.0f, 81.0f, 69.0f, 69.0f, 81.0f, 20.0f, 86.0f, 1.0f, 74.0f, 70.0f,
//      48.0f, 55.0f,
//      51.0f, 97.0f, 39.0f, 0.0f, 30.0f, 42.0f, 87.0f, 71.0f, 4.0f, 62.0f, 25.0f, 17.0f,
//      50.0f, 63.0f,
//      17.0f, 73.0f, 1.0f, 22.0f, 33.0f, 27.0f, 82.0f, 29.0f, 2.0f, 62.0f, 55.0f, 10.0f,
//      90.0f, 53.0f,
//      88.0f, 54.0f, 9.0f, 75.0f, 49.0f, 16.0f, 66.0f, 73.0f, 29.0f, 60.0f, 1.0f, 77.0f,
//      21.0f, 25.0f,
//      60.0f, 26.0f, 14.0f, 68.0f, 31.0f, 16.0f, 50.0f, 70.0f, 92.0f, 16.0f, 67.0f, 50.0f,
//      76.0f, 53.0f,
//      61.0f, 70.0f, 49.0f, 69.0f, 20.0f, 6.0f, 21.0f, 38.0f, 94.0f, 13.0f, 92.0f, 15.0f,
//      63.0f, 46.0f,
//      19.0f, 99.0f, 18.0f, 61.0f, 22.0f, 46.0f, 60.0f, 86.0f, 15.0f, 33.0f, 93.0f, 78.0f,
//      82.0f, 18.0f,
//      12.0f, 40.0f, 62.0f, 36.0f, 7.0f, 88.0f, 58.0f, 78.0f, 58.0f, 35.0f, 10.0f, 60.0f,
//      30.0f, 88.0f,
//      28.0f, 76.0f, 98.0f, 93.0f, 89.0f, 42.0f, 96.0f, 42.0f, 91.0f, 45.0f, 48.0f, 35.0f,
//      25.0f, 20.0f,
//      94.0f, 38.0f, 73.0f, 56.0f, 76.0f, 5.0f, 64.0f, 97.0f, 9.0f, 52.0f, 46.0f, 15.0f,
//      82.0f, 16.0f,
//      3.0f, 21.0f, 67.0f, 66.0f, 92.0f, 83.0f, 36.0f, 4.0f, 33.0f, 54.0f, 27.0f, 47.0f,
//      8.0f, 96.0f,
//      26.0f, 89.0f, 74.0f, 16.0f, 22.0f, 68.0f, 67.0f, 36.0f, 85.0f, 16.0f, 1.0f, 96.0f,
//      91.0f, 41.0f,
//      72.0f, 95.0f, 73.0f, 31.0f, 86.0f, 3.0f, 52.0f, 6.0f, 62.0f, 29.0f, 30.0f, 90.0f),
//      Array(1, 768))
//    val gradO2 = Tensor[Float](1, 11, 768)
//    val gradOutput = T.array(Array.fill[Tensor[Float]](12)(gradO2) :+ gradPoolOutput)
//    val gradInput = layer.backward(input, gradOutput).toTable
//
//    val gradients = layer.parameters()._2
//    TestUtils.conditionFailTest(Math.abs(gradients(2).apply(Array(1, 1)) - 151.8128) < 0.005)
//    TestUtils.conditionFailTest(Math.abs(gradients(2).apply(Array(1, 2)) - (-249.5875)) < 0.015)
//    TestUtils.conditionFailTest(Math.abs(gradients(0).apply(Array(1, 1)) - (-16.2635)) < 0.006)
//    TestUtils.conditionFailTest(Math.abs(gradients(0).apply(Array(1, 2)) - (-18.1143)) < 0.016)
//    TestUtils.conditionFailTest(Math.abs(gradients(5).apply(Array(1, 1)) - 0.0743) < 5e-5)
//    TestUtils.conditionFailTest(Math.abs(gradients(5).apply(Array(1, 2)) - 0.6531) < 7e-5)
//  }

  "Bert gelu" should "be able to generate correct result" in {
    // Build a small BERT instance purely to get access to its gelu activation
    // helper; the layer itself is never forwarded in this test.
    val layer = BERT[Float](vocab = 100,
      hiddenSize = 10,
      nBlock = 3,
      nHead = 2,
      intermediateSize = 64,
      hiddenPDrop = 0.1,
      attnPDrop = 0.1,
      maxPositionLen = 10,
      outputAllBlock = false)

    // Input values and the upstream gradient used for the backward check.
    val xValue = Tensor[Float](Array[Float](2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f,
      10f, 11f, 12f, 13f, 14f, 15f, 16f, 17f), Array(2, 2, 4))
    val gradOValue = Tensor[Float](Array[Float](20f, 21f, 22f, 23f,
      24f, 25f, 26f, 27f,
      28f, 29f, 30f, 31f,
      32f, 33f, 34f, 35f), Array(2, 2, 4))
    val x = Variable[Float](inputShape = Shape(2, 4))

    // Wrap the gelu op in a graph model so forward/backward can be exercised.
    val y = layer.gelu(x)
    val model = Model[Float](input = x, output = y)

    val output2 = model.forward(xValue).toTensor[Float]
    // gelu(x) -> x as x grows positive, so most expected values match the input.
    val expectOutput = Tensor[Float](Array[Float](1.9545f, 2.9960f, 3.9999f, 5.0000f,
      6.0000f, 7.0000f, 8.0000f, 9.0000f,
      10.0000f, 11.0000f, 12.0000f, 13.0000f,
      14.0000f, 15.0000f, 16.0000f, 17.0000f), Array(2, 2, 4))
    // almostEqual already returns Boolean; no need for a redundant "== true".
    TestUtils.conditionFailTest(output2.almostEqual(expectOutput, 5e-5))

    // Similarly, the gelu derivative approaches 1 for large positive inputs,
    // so most gradient entries equal the upstream gradient.
    val expectGradInput = Tensor[Float](Array[Float](21.7046f, 21.2509f, 22.0111f, 23.0002f,
      24.0000f, 25.0000f, 26.0000f, 27.0000f,
      28.0000f, 29.0000f, 30.0000f, 31.0000f,
      32.0000f, 33.0000f, 34.0000f, 35.0000f), Array(2, 2, 4))
    val gradInput = model.backward(xValue, gradOValue).toTensor[Float]
    TestUtils.conditionFailTest(gradInput.almostEqual(expectGradInput, 5e-5))
  }
}

class BERTSerialTest extends ModuleSerializationTest {
  // Intentionally a no-op: BERT save/load round-tripping is already exercised
  // by the "save/load be able to work" test in BertSpec, so running the
  // generic serialization check here would be redundant.
  override def test(): Unit = {}
}
